diff --git a/.binder/environment.yml b/.binder/environment.yml index 13b6b99e6fc..6fd5829c5e6 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -2,7 +2,7 @@ name: xarray-examples channels: - conda-forge dependencies: - - python=3.7 + - python=3.8 - boto3 - bottleneck - cartopy @@ -31,6 +31,7 @@ dependencies: - rasterio - scipy - seaborn + - setuptools - sparse - toolz - xarray diff --git a/.coveragerc b/.coveragerc index 2bf6e08f5af..1bf19c310aa 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,4 +5,3 @@ omit = xarray/core/npcompat.py xarray/core/pdcompat.py xarray/core/pycompat.py - xarray/_version.py diff --git a/.gitattributes b/.gitattributes index daa5b82874e..a52f4ca283a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,2 @@ # reduce the number of merge conflicts doc/whats-new.rst merge=union -xarray/_version.py export-subst diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index f24884c617a..31fef19b32a 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -24,6 +24,6 @@ assignees: '' #### Output of ``xr.show_versions()``
-# Paste the output here xr.show_versions() here +
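The hunk above edits the issue template's prompt for ``xr.show_versions()`` output. For reference, a minimal sketch of producing that output — assuming only that xarray is installed — is::

    import xarray as xr

    # Prints Python/OS details plus xarray and dependency versions,
    # which is what the bug-report template asks reporters to paste.
    xr.show_versions()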
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d78dd38dd85..a921bddaa23 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,5 +2,5 @@ - [ ] Closes #xxxx - [ ] Tests added - - [ ] Passes `black . && mypy . && flake8` + - [ ] Passes `isort -rc . && black . && mypy . && flake8` - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API diff --git a/.gitignore b/.gitignore index ad26864221e..5f02700de37 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ doc/savefig # Packages *.egg *.egg-info +.eggs dist build eggs @@ -65,7 +66,6 @@ dask-worker-space/ # xarray specific doc/_build doc/generated -xarray/version.py xarray/tests/data/*.grib.*.idx # Sync tools diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 502120cd5dc..9df95648774 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,20 +1,24 @@ # https://pre-commit.com/ -# https://github.com/python/black#version-control-integration repos: + # isort should run before black as black sometimes tweaks the isort output + - repo: https://github.com/timothycrosley/isort + rev: 4.3.21-2 + hooks: + - id: isort + # https://github.com/python/black#version-control-integration - repo: https://github.com/python/black rev: stable hooks: - id: black - language_version: python3.7 - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.2.3 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.730 # Must match ci/requirements/*.yml + rev: v0.761 # Must match ci/requirements/*.yml hooks: - id: mypy - # run these occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 + # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 # - repo: https://github.com/asottile/pyupgrade # rev: v1.22.1 # hooks: @@ -23,7 +27,3 @@ repos: # - "--py3-only" # # remove on f-strings in Py3.7 # - "--keep-percent-format" - # - repo: https://github.com/timothycrosley/isort - # rev: 4.3.21-2 - # hooks: - # - id: isort diff --git a/MANIFEST.in b/MANIFEST.in index 4d5c34f622c..cbfb8c8cdca 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,4 @@ recursive-include doc * prune doc/_build prune doc/generated global-exclude .DS_Store -include versioneer.py -include xarray/_version.py recursive-include xarray/static * diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 11a779ae376..d35a2a223a2 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -40,7 +40,7 @@ // The Pythons you'd like to test against. If not provided, defaults // to the current version of Python used to run `asv`. - "pythons": ["3.6"], + "pythons": ["3.8"], // The matrix of dependencies to test. Each key is the name of a // package (in PyPI) and the values are version numbers. 
An empty diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 90de0705a27..5789161c966 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -20,11 +20,11 @@ jobs: conda_env: py37 py38: conda_env: py38 - py37-upstream-dev: - conda_env: py37 + py38-upstream-dev: + conda_env: py38 upstream_dev: true - py36-flaky: - conda_env: py36 + py38-flaky: + conda_env: py38 pytest_extra_flags: --run-flaky --run-network-tests allow_failure: true pool: @@ -35,8 +35,8 @@ jobs: - job: MacOSX strategy: matrix: - py36: - conda_env: py36 + py38: + conda_env: py38 pool: vmImage: 'macOS-10.13' steps: @@ -74,7 +74,7 @@ jobs: - job: TypeChecking variables: - conda_env: py37 + conda_env: py38 pool: vmImage: 'ubuntu-16.04' steps: @@ -84,6 +84,18 @@ jobs: mypy . displayName: mypy type checks +- job: isort + variables: + conda_env: py38 + pool: + vmImage: 'ubuntu-16.04' + steps: + - template: ci/azure/install.yml + - bash: | + source activate xarray-tests + isort -rc --check . + displayName: isort formatting checks + - job: MinimumVersionsPolicy pool: vmImage: 'ubuntu-16.04' @@ -110,5 +122,5 @@ jobs: - bash: | source activate xarray-tests cd doc - sphinx-build -n -j auto -b html -d _build/doctrees . _build/html + sphinx-build -W --keep-going -j auto -b html -d _build/doctrees . _build/html displayName: Build HTML docs diff --git a/ci/azure/install.yml b/ci/azure/install.yml index e4f3a0b9e16..958e3c180fa 100644 --- a/ci/azure/install.yml +++ b/ci/azure/install.yml @@ -6,12 +6,14 @@ steps: - template: add-conda-to-path.yml - bash: | + conda update -y conda conda env create -n xarray-tests --file ${{ parameters.env_file }} displayName: Install conda dependencies - bash: | source activate xarray-tests - pip install -f https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com \ + python -m pip install \ + -f https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com \ --no-deps \ --pre \ --upgrade \ @@ -19,7 +21,7 @@ steps: numpy \ pandas \ scipy - pip install \ + python -m pip install \ --no-deps \ --upgrade \ git+https://github.com/dask/dask \ @@ -27,13 +29,14 @@ steps: git+https://github.com/zarr-developers/zarr \ git+https://github.com/Unidata/cftime \ git+https://github.com/mapbox/rasterio \ + git+https://github.com/hgrecco/pint \ git+https://github.com/pydata/bottleneck condition: eq(variables['UPSTREAM_DEV'], 'true') displayName: Install upstream dev dependencies - bash: | source activate xarray-tests - pip install --no-deps -e . + python -m pip install --no-deps -e . 
displayName: Install xarray - bash: | diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index a5ba90679b7..527093cf5bc 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -15,6 +15,7 @@ "coveralls", "flake8", "hypothesis", + "isort", "mypy", "pip", "pytest", diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index a0c27a30b01..2c44e754cc4 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -3,7 +3,7 @@ channels: # Don't change to pkgs/main, as it causes random timeouts in readthedocs - conda-forge dependencies: - - python=3.7 + - python=3.8 - bottleneck - cartopy - cfgrib @@ -14,11 +14,13 @@ dependencies: - jupyter_client - nbsphinx - netcdf4 + - numba - numpy - numpydoc - - pandas<0.25 # Hack around https://github.com/pydata/xarray/issues/3369 + - pandas - rasterio - seaborn + - setuptools - sphinx - sphinx_rtd_theme - zarr diff --git a/ci/requirements/py36-bare-minimum.yml b/ci/requirements/py36-bare-minimum.yml index 05186bc8748..00fef672855 100644 --- a/ci/requirements/py36-bare-minimum.yml +++ b/ci/requirements/py36-bare-minimum.yml @@ -4,8 +4,10 @@ channels: dependencies: - python=3.6 - coveralls + - pip - pytest - pytest-cov - pytest-env - - numpy=1.14 - - pandas=0.24 + - numpy=1.15 + - pandas=0.25 + - setuptools=41.2 diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index 3f10a158f91..86540197dcc 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -15,22 +15,24 @@ dependencies: - cfgrib=0.9 - cftime=1.0 - coveralls - - dask=1.2 - - distributed=1.27 + - dask=2.2 + - distributed=2.2 - flake8 - h5netcdf=0.7 - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest - hdf5=1.10 - hypothesis - iris=2.2 + - isort - lxml=4.4 # Optional dep of pydap - matplotlib=3.1 - - mypy=0.730 # Must match .pre-commit-config.yaml + - msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491 + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis=1.2 - netcdf4=1.4 - numba=0.44 - - numpy=1.14 - - pandas=0.24 + - numpy=1.15 + - pandas=0.25 # - pint # See py36-min-nep18.yml - pip - pseudonetcdf=3.0 @@ -40,8 +42,9 @@ dependencies: - pytest-cov - pytest-env - rasterio=1.0 - - scipy=1.0 # Policy allows for 1.2, but scipy>=1.1 breaks numpy=1.14 + - scipy=1.3 - seaborn=0.9 + - setuptools=41.2 # - sparse # See py36-min-nep18.yml - toolz=0.10 - zarr=2.3 diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml index fc9523ce249..c10fdf67dc4 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py36-min-nep18.yml @@ -8,11 +8,14 @@ dependencies: - coveralls - dask=2.4 - distributed=2.4 + - msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491 - numpy=1.17 - - pandas=0.24 + - pandas=0.25 - pint=0.9 # Actually not enough as it doesn't implement __array_function__yet! 
+ - pip - pytest - pytest-cov - pytest-env - scipy=1.2 + - setuptools=41.2 - sparse=0.8 diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml index 820160b19cc..a500173f277 100644 --- a/ci/requirements/py36.yml +++ b/ci/requirements/py36.yml @@ -19,9 +19,10 @@ dependencies: - hdf5 - hypothesis - iris - - lxml # optional dep of pydap + - isort + - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba @@ -38,6 +39,7 @@ dependencies: - rasterio - scipy - seaborn + - setuptools - sparse - toolz - zarr diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/py37-windows.yml index 614a3bb1fab..e9e5c7a900a 100644 --- a/ci/requirements/py37-windows.yml +++ b/ci/requirements/py37-windows.yml @@ -19,13 +19,14 @@ dependencies: - hdf5 - hypothesis - iris + - isort - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba - - numpy<1.18 # FIXME https://github.com/pydata/xarray/issues/3409 + - numpy - pandas - pint - pip @@ -38,6 +39,7 @@ dependencies: - rasterio - scipy - seaborn + - setuptools - sparse - toolz - zarr diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index 4a7aaf7d32b..dba3926596e 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -19,9 +19,10 @@ dependencies: - hdf5 - hypothesis - iris + - isort - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba @@ -38,6 +39,7 @@ dependencies: - rasterio - scipy - seaborn + - setuptools - sparse - toolz - zarr diff --git a/ci/requirements/py38.yml b/ci/requirements/py38.yml index 9698e3efecf..24602f884e9 100644 --- a/ci/requirements/py38.yml +++ b/ci/requirements/py38.yml @@ -3,13 +3,45 @@ channels: - conda-forge dependencies: - python=3.8 + - black + - boto3 + - bottleneck + - cartopy + - cdms2 + - cfgrib + - cftime + - coveralls + - dask + - distributed + - flake8 + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - isort + - lxml # Optional dep of pydap + - matplotlib + - mypy=0.761 # Must match .pre-commit-config.yaml + - nc-time-axis + - netcdf4 + - numba + - numpy + - pandas + - pint - pip + - pseudonetcdf + - pydap + - pynio + - pytest + - pytest-cov + - pytest-env + - rasterio + - scipy + - seaborn + - setuptools + - sparse + - toolz + - zarr - pip: - - coveralls - - dask - - distributed - - numpy - - pandas - - pytest - - pytest-cov - - pytest-env + - numbagg diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 027c732697f..437f53b1a91 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -27,6 +27,38 @@ Dataset.std Dataset.var + core.coordinates.DatasetCoordinates.get + core.coordinates.DatasetCoordinates.items + core.coordinates.DatasetCoordinates.keys + core.coordinates.DatasetCoordinates.merge + core.coordinates.DatasetCoordinates.to_dataset + core.coordinates.DatasetCoordinates.to_index + core.coordinates.DatasetCoordinates.update + core.coordinates.DatasetCoordinates.values + core.coordinates.DatasetCoordinates.dims + core.coordinates.DatasetCoordinates.indexes + core.coordinates.DatasetCoordinates.variables + + core.rolling.DatasetCoarsen.all + core.rolling.DatasetCoarsen.any + core.rolling.DatasetCoarsen.argmax + core.rolling.DatasetCoarsen.argmin + core.rolling.DatasetCoarsen.count + 
core.rolling.DatasetCoarsen.max + core.rolling.DatasetCoarsen.mean + core.rolling.DatasetCoarsen.median + core.rolling.DatasetCoarsen.min + core.rolling.DatasetCoarsen.prod + core.rolling.DatasetCoarsen.std + core.rolling.DatasetCoarsen.sum + core.rolling.DatasetCoarsen.var + core.rolling.DatasetCoarsen.boundary + core.rolling.DatasetCoarsen.coord_func + core.rolling.DatasetCoarsen.obj + core.rolling.DatasetCoarsen.side + core.rolling.DatasetCoarsen.trim_excess + core.rolling.DatasetCoarsen.windows + core.groupby.DatasetGroupBy.assign core.groupby.DatasetGroupBy.assign_coords core.groupby.DatasetGroupBy.first @@ -34,6 +66,69 @@ core.groupby.DatasetGroupBy.fillna core.groupby.DatasetGroupBy.quantile core.groupby.DatasetGroupBy.where + core.groupby.DatasetGroupBy.all + core.groupby.DatasetGroupBy.any + core.groupby.DatasetGroupBy.argmax + core.groupby.DatasetGroupBy.argmin + core.groupby.DatasetGroupBy.count + core.groupby.DatasetGroupBy.max + core.groupby.DatasetGroupBy.mean + core.groupby.DatasetGroupBy.median + core.groupby.DatasetGroupBy.min + core.groupby.DatasetGroupBy.prod + core.groupby.DatasetGroupBy.std + core.groupby.DatasetGroupBy.sum + core.groupby.DatasetGroupBy.var + core.groupby.DatasetGroupBy.dims + core.groupby.DatasetGroupBy.groups + + core.resample.DatasetResample.all + core.resample.DatasetResample.any + core.resample.DatasetResample.apply + core.resample.DatasetResample.argmax + core.resample.DatasetResample.argmin + core.resample.DatasetResample.assign + core.resample.DatasetResample.assign_coords + core.resample.DatasetResample.bfill + core.resample.DatasetResample.count + core.resample.DatasetResample.ffill + core.resample.DatasetResample.fillna + core.resample.DatasetResample.first + core.resample.DatasetResample.last + core.resample.DatasetResample.map + core.resample.DatasetResample.max + core.resample.DatasetResample.mean + core.resample.DatasetResample.median + core.resample.DatasetResample.min + core.resample.DatasetResample.prod + core.resample.DatasetResample.quantile + core.resample.DatasetResample.reduce + core.resample.DatasetResample.std + core.resample.DatasetResample.sum + core.resample.DatasetResample.var + core.resample.DatasetResample.where + core.resample.DatasetResample.dims + core.resample.DatasetResample.groups + + core.rolling.DatasetRolling.argmax + core.rolling.DatasetRolling.argmin + core.rolling.DatasetRolling.count + core.rolling.DatasetRolling.max + core.rolling.DatasetRolling.mean + core.rolling.DatasetRolling.median + core.rolling.DatasetRolling.min + core.rolling.DatasetRolling.prod + core.rolling.DatasetRolling.std + core.rolling.DatasetRolling.sum + core.rolling.DatasetRolling.var + core.rolling.DatasetRolling.center + core.rolling.DatasetRolling.dim + core.rolling.DatasetRolling.min_periods + core.rolling.DatasetRolling.obj + core.rolling.DatasetRolling.rollings + core.rolling.DatasetRolling.window + + core.rolling_exp.RollingExp.mean Dataset.argsort Dataset.astype @@ -47,6 +142,9 @@ Dataset.cumprod Dataset.rank + Dataset.load_store + Dataset.dump_to_store + DataArray.ndim DataArray.nbytes DataArray.shape @@ -71,12 +169,104 @@ DataArray.std DataArray.var + core.coordinates.DataArrayCoordinates.get + core.coordinates.DataArrayCoordinates.items + core.coordinates.DataArrayCoordinates.keys + core.coordinates.DataArrayCoordinates.merge + core.coordinates.DataArrayCoordinates.to_dataset + core.coordinates.DataArrayCoordinates.to_index + core.coordinates.DataArrayCoordinates.update + core.coordinates.DataArrayCoordinates.values + 
core.coordinates.DataArrayCoordinates.dims + core.coordinates.DataArrayCoordinates.indexes + core.coordinates.DataArrayCoordinates.variables + + core.rolling.DataArrayCoarsen.all + core.rolling.DataArrayCoarsen.any + core.rolling.DataArrayCoarsen.argmax + core.rolling.DataArrayCoarsen.argmin + core.rolling.DataArrayCoarsen.count + core.rolling.DataArrayCoarsen.max + core.rolling.DataArrayCoarsen.mean + core.rolling.DataArrayCoarsen.median + core.rolling.DataArrayCoarsen.min + core.rolling.DataArrayCoarsen.prod + core.rolling.DataArrayCoarsen.std + core.rolling.DataArrayCoarsen.sum + core.rolling.DataArrayCoarsen.var + core.rolling.DataArrayCoarsen.boundary + core.rolling.DataArrayCoarsen.coord_func + core.rolling.DataArrayCoarsen.obj + core.rolling.DataArrayCoarsen.side + core.rolling.DataArrayCoarsen.trim_excess + core.rolling.DataArrayCoarsen.windows + core.groupby.DataArrayGroupBy.assign_coords core.groupby.DataArrayGroupBy.first core.groupby.DataArrayGroupBy.last core.groupby.DataArrayGroupBy.fillna core.groupby.DataArrayGroupBy.quantile core.groupby.DataArrayGroupBy.where + core.groupby.DataArrayGroupBy.all + core.groupby.DataArrayGroupBy.any + core.groupby.DataArrayGroupBy.argmax + core.groupby.DataArrayGroupBy.argmin + core.groupby.DataArrayGroupBy.count + core.groupby.DataArrayGroupBy.max + core.groupby.DataArrayGroupBy.mean + core.groupby.DataArrayGroupBy.median + core.groupby.DataArrayGroupBy.min + core.groupby.DataArrayGroupBy.prod + core.groupby.DataArrayGroupBy.std + core.groupby.DataArrayGroupBy.sum + core.groupby.DataArrayGroupBy.var + core.groupby.DataArrayGroupBy.dims + core.groupby.DataArrayGroupBy.groups + + core.resample.DataArrayResample.all + core.resample.DataArrayResample.any + core.resample.DataArrayResample.apply + core.resample.DataArrayResample.argmax + core.resample.DataArrayResample.argmin + core.resample.DataArrayResample.assign_coords + core.resample.DataArrayResample.bfill + core.resample.DataArrayResample.count + core.resample.DataArrayResample.ffill + core.resample.DataArrayResample.fillna + core.resample.DataArrayResample.first + core.resample.DataArrayResample.last + core.resample.DataArrayResample.map + core.resample.DataArrayResample.max + core.resample.DataArrayResample.mean + core.resample.DataArrayResample.median + core.resample.DataArrayResample.min + core.resample.DataArrayResample.prod + core.resample.DataArrayResample.quantile + core.resample.DataArrayResample.reduce + core.resample.DataArrayResample.std + core.resample.DataArrayResample.sum + core.resample.DataArrayResample.var + core.resample.DataArrayResample.where + core.resample.DataArrayResample.dims + core.resample.DataArrayResample.groups + + core.rolling.DataArrayRolling.argmax + core.rolling.DataArrayRolling.argmin + core.rolling.DataArrayRolling.count + core.rolling.DataArrayRolling.max + core.rolling.DataArrayRolling.mean + core.rolling.DataArrayRolling.median + core.rolling.DataArrayRolling.min + core.rolling.DataArrayRolling.prod + core.rolling.DataArrayRolling.std + core.rolling.DataArrayRolling.sum + core.rolling.DataArrayRolling.var + core.rolling.DataArrayRolling.center + core.rolling.DataArrayRolling.dim + core.rolling.DataArrayRolling.min_periods + core.rolling.DataArrayRolling.obj + core.rolling.DataArrayRolling.window + core.rolling.DataArrayRolling.window_labels DataArray.argsort DataArray.clip @@ -91,6 +281,221 @@ DataArray.cumprod DataArray.rank + core.accessor_dt.DatetimeAccessor.ceil + core.accessor_dt.DatetimeAccessor.floor + core.accessor_dt.DatetimeAccessor.round + 
core.accessor_dt.DatetimeAccessor.strftime + core.accessor_dt.DatetimeAccessor.day + core.accessor_dt.DatetimeAccessor.dayofweek + core.accessor_dt.DatetimeAccessor.dayofyear + core.accessor_dt.DatetimeAccessor.days_in_month + core.accessor_dt.DatetimeAccessor.daysinmonth + core.accessor_dt.DatetimeAccessor.hour + core.accessor_dt.DatetimeAccessor.microsecond + core.accessor_dt.DatetimeAccessor.minute + core.accessor_dt.DatetimeAccessor.month + core.accessor_dt.DatetimeAccessor.nanosecond + core.accessor_dt.DatetimeAccessor.quarter + core.accessor_dt.DatetimeAccessor.season + core.accessor_dt.DatetimeAccessor.second + core.accessor_dt.DatetimeAccessor.time + core.accessor_dt.DatetimeAccessor.week + core.accessor_dt.DatetimeAccessor.weekday + core.accessor_dt.DatetimeAccessor.weekday_name + core.accessor_dt.DatetimeAccessor.weekofyear + core.accessor_dt.DatetimeAccessor.year + + core.accessor_str.StringAccessor.capitalize + core.accessor_str.StringAccessor.center + core.accessor_str.StringAccessor.contains + core.accessor_str.StringAccessor.count + core.accessor_str.StringAccessor.decode + core.accessor_str.StringAccessor.encode + core.accessor_str.StringAccessor.endswith + core.accessor_str.StringAccessor.find + core.accessor_str.StringAccessor.get + core.accessor_str.StringAccessor.index + core.accessor_str.StringAccessor.isalnum + core.accessor_str.StringAccessor.isalpha + core.accessor_str.StringAccessor.isdecimal + core.accessor_str.StringAccessor.isdigit + core.accessor_str.StringAccessor.islower + core.accessor_str.StringAccessor.isnumeric + core.accessor_str.StringAccessor.isspace + core.accessor_str.StringAccessor.istitle + core.accessor_str.StringAccessor.isupper + core.accessor_str.StringAccessor.len + core.accessor_str.StringAccessor.ljust + core.accessor_str.StringAccessor.lower + core.accessor_str.StringAccessor.lstrip + core.accessor_str.StringAccessor.match + core.accessor_str.StringAccessor.pad + core.accessor_str.StringAccessor.repeat + core.accessor_str.StringAccessor.replace + core.accessor_str.StringAccessor.rfind + core.accessor_str.StringAccessor.rindex + core.accessor_str.StringAccessor.rjust + core.accessor_str.StringAccessor.rstrip + core.accessor_str.StringAccessor.slice + core.accessor_str.StringAccessor.slice_replace + core.accessor_str.StringAccessor.startswith + core.accessor_str.StringAccessor.strip + core.accessor_str.StringAccessor.swapcase + core.accessor_str.StringAccessor.title + core.accessor_str.StringAccessor.translate + core.accessor_str.StringAccessor.upper + core.accessor_str.StringAccessor.wrap + core.accessor_str.StringAccessor.zfill + + Variable.all + Variable.any + Variable.argmax + Variable.argmin + Variable.argsort + Variable.astype + Variable.broadcast_equals + Variable.chunk + Variable.clip + Variable.coarsen + Variable.compute + Variable.concat + Variable.conj + Variable.conjugate + Variable.copy + Variable.count + Variable.cumprod + Variable.cumsum + Variable.equals + Variable.fillna + Variable.get_axis_num + Variable.identical + Variable.isel + Variable.isnull + Variable.item + Variable.load + Variable.max + Variable.mean + Variable.median + Variable.min + Variable.no_conflicts + Variable.notnull + Variable.pad_with_fill_value + Variable.prod + Variable.quantile + Variable.rank + Variable.reduce + Variable.roll + Variable.rolling_window + Variable.round + Variable.searchsorted + Variable.set_dims + Variable.shift + Variable.squeeze + Variable.stack + Variable.std + Variable.sum + Variable.to_base_variable + Variable.to_coord + 
Variable.to_dict + Variable.to_index + Variable.to_index_variable + Variable.to_variable + Variable.transpose + Variable.unstack + Variable.var + Variable.where + Variable.T + Variable.attrs + Variable.chunks + Variable.data + Variable.dims + Variable.dtype + Variable.encoding + Variable.imag + Variable.nbytes + Variable.ndim + Variable.real + Variable.shape + Variable.size + Variable.sizes + Variable.values + + IndexVariable.all + IndexVariable.any + IndexVariable.argmax + IndexVariable.argmin + IndexVariable.argsort + IndexVariable.astype + IndexVariable.broadcast_equals + IndexVariable.chunk + IndexVariable.clip + IndexVariable.coarsen + IndexVariable.compute + IndexVariable.concat + IndexVariable.conj + IndexVariable.conjugate + IndexVariable.copy + IndexVariable.count + IndexVariable.cumprod + IndexVariable.cumsum + IndexVariable.equals + IndexVariable.fillna + IndexVariable.get_axis_num + IndexVariable.get_level_variable + IndexVariable.identical + IndexVariable.isel + IndexVariable.isnull + IndexVariable.item + IndexVariable.load + IndexVariable.max + IndexVariable.mean + IndexVariable.median + IndexVariable.min + IndexVariable.no_conflicts + IndexVariable.notnull + IndexVariable.pad_with_fill_value + IndexVariable.prod + IndexVariable.quantile + IndexVariable.rank + IndexVariable.reduce + IndexVariable.roll + IndexVariable.rolling_window + IndexVariable.round + IndexVariable.searchsorted + IndexVariable.set_dims + IndexVariable.shift + IndexVariable.squeeze + IndexVariable.stack + IndexVariable.std + IndexVariable.sum + IndexVariable.to_base_variable + IndexVariable.to_coord + IndexVariable.to_dict + IndexVariable.to_index + IndexVariable.to_index_variable + IndexVariable.to_variable + IndexVariable.transpose + IndexVariable.unstack + IndexVariable.var + IndexVariable.where + IndexVariable.T + IndexVariable.attrs + IndexVariable.chunks + IndexVariable.data + IndexVariable.dims + IndexVariable.dtype + IndexVariable.encoding + IndexVariable.imag + IndexVariable.level_names + IndexVariable.name + IndexVariable.nbytes + IndexVariable.ndim + IndexVariable.real + IndexVariable.shape + IndexVariable.size + IndexVariable.sizes + IndexVariable.values + ufuncs.angle ufuncs.arccos ufuncs.arccosh @@ -156,6 +561,242 @@ plot.FacetGrid.set_ticks plot.FacetGrid.map + CFTimeIndex.all + CFTimeIndex.any + CFTimeIndex.append + CFTimeIndex.argmax + CFTimeIndex.argmin + CFTimeIndex.argsort + CFTimeIndex.asof + CFTimeIndex.asof_locs + CFTimeIndex.astype + CFTimeIndex.contains + CFTimeIndex.copy + CFTimeIndex.delete + CFTimeIndex.difference + CFTimeIndex.drop + CFTimeIndex.drop_duplicates + CFTimeIndex.droplevel + CFTimeIndex.dropna + CFTimeIndex.duplicated + CFTimeIndex.equals + CFTimeIndex.factorize + CFTimeIndex.fillna + CFTimeIndex.format + CFTimeIndex.get_indexer + CFTimeIndex.get_indexer_for + CFTimeIndex.get_indexer_non_unique + CFTimeIndex.get_level_values + CFTimeIndex.get_loc + CFTimeIndex.get_slice_bound + CFTimeIndex.get_value + CFTimeIndex.groupby + CFTimeIndex.holds_integer + CFTimeIndex.identical + CFTimeIndex.insert + CFTimeIndex.intersection + CFTimeIndex.is_ + CFTimeIndex.is_boolean + CFTimeIndex.is_categorical + CFTimeIndex.is_floating + CFTimeIndex.is_integer + CFTimeIndex.is_interval + CFTimeIndex.is_mixed + CFTimeIndex.is_numeric + CFTimeIndex.is_object + CFTimeIndex.is_type_compatible + CFTimeIndex.isin + CFTimeIndex.isna + CFTimeIndex.isnull + CFTimeIndex.item + CFTimeIndex.join + CFTimeIndex.map + CFTimeIndex.max + CFTimeIndex.memory_usage + CFTimeIndex.min + CFTimeIndex.notna + 
CFTimeIndex.notnull + CFTimeIndex.nunique + CFTimeIndex.putmask + CFTimeIndex.ravel + CFTimeIndex.reindex + CFTimeIndex.rename + CFTimeIndex.repeat + CFTimeIndex.searchsorted + CFTimeIndex.set_names + CFTimeIndex.set_value CFTimeIndex.shift - CFTimeIndex.to_datetimeindex + CFTimeIndex.slice_indexer + CFTimeIndex.slice_locs + CFTimeIndex.sort + CFTimeIndex.sort_values + CFTimeIndex.sortlevel CFTimeIndex.strftime + CFTimeIndex.symmetric_difference + CFTimeIndex.take + CFTimeIndex.to_datetimeindex + CFTimeIndex.to_flat_index + CFTimeIndex.to_frame + CFTimeIndex.to_list + CFTimeIndex.to_native_types + CFTimeIndex.to_numpy + CFTimeIndex.to_series + CFTimeIndex.tolist + CFTimeIndex.transpose + CFTimeIndex.union + CFTimeIndex.unique + CFTimeIndex.value_counts + CFTimeIndex.view + CFTimeIndex.where + + CFTimeIndex.T + CFTimeIndex.array + CFTimeIndex.asi8 + CFTimeIndex.date_type + CFTimeIndex.day + CFTimeIndex.dayofweek + CFTimeIndex.dayofyear + CFTimeIndex.dtype + CFTimeIndex.empty + CFTimeIndex.has_duplicates + CFTimeIndex.hasnans + CFTimeIndex.hour + CFTimeIndex.inferred_type + CFTimeIndex.is_all_dates + CFTimeIndex.is_monotonic + CFTimeIndex.is_monotonic_increasing + CFTimeIndex.is_monotonic_decreasing + CFTimeIndex.is_unique + CFTimeIndex.microsecond + CFTimeIndex.minute + CFTimeIndex.month + CFTimeIndex.name + CFTimeIndex.names + CFTimeIndex.nbytes + CFTimeIndex.ndim + CFTimeIndex.nlevels + CFTimeIndex.second + CFTimeIndex.shape + CFTimeIndex.size + CFTimeIndex.values + CFTimeIndex.year + + backends.NetCDF4DataStore.close + backends.NetCDF4DataStore.encode + backends.NetCDF4DataStore.encode_attribute + backends.NetCDF4DataStore.encode_variable + backends.NetCDF4DataStore.get + backends.NetCDF4DataStore.get_attrs + backends.NetCDF4DataStore.get_dimensions + backends.NetCDF4DataStore.get_encoding + backends.NetCDF4DataStore.get_variables + backends.NetCDF4DataStore.items + backends.NetCDF4DataStore.keys + backends.NetCDF4DataStore.load + backends.NetCDF4DataStore.open + backends.NetCDF4DataStore.open_store_variable + backends.NetCDF4DataStore.prepare_variable + backends.NetCDF4DataStore.set_attribute + backends.NetCDF4DataStore.set_attributes + backends.NetCDF4DataStore.set_dimension + backends.NetCDF4DataStore.set_dimensions + backends.NetCDF4DataStore.set_variable + backends.NetCDF4DataStore.set_variables + backends.NetCDF4DataStore.store + backends.NetCDF4DataStore.store_dataset + backends.NetCDF4DataStore.sync + backends.NetCDF4DataStore.values + backends.NetCDF4DataStore.attrs + backends.NetCDF4DataStore.autoclose + backends.NetCDF4DataStore.dimensions + backends.NetCDF4DataStore.ds + backends.NetCDF4DataStore.format + backends.NetCDF4DataStore.is_remote + backends.NetCDF4DataStore.lock + backends.NetCDF4DataStore.variables + + backends.H5NetCDFStore.close + backends.H5NetCDFStore.encode + backends.H5NetCDFStore.encode_attribute + backends.H5NetCDFStore.encode_variable + backends.H5NetCDFStore.get + backends.H5NetCDFStore.get_attrs + backends.H5NetCDFStore.get_dimensions + backends.H5NetCDFStore.get_encoding + backends.H5NetCDFStore.get_variables + backends.H5NetCDFStore.items + backends.H5NetCDFStore.keys + backends.H5NetCDFStore.load + backends.H5NetCDFStore.open_store_variable + backends.H5NetCDFStore.prepare_variable + backends.H5NetCDFStore.set_attribute + backends.H5NetCDFStore.set_attributes + backends.H5NetCDFStore.set_dimension + backends.H5NetCDFStore.set_dimensions + backends.H5NetCDFStore.set_variable + backends.H5NetCDFStore.set_variables + backends.H5NetCDFStore.store + 
backends.H5NetCDFStore.store_dataset + backends.H5NetCDFStore.sync + backends.H5NetCDFStore.values + backends.H5NetCDFStore.attrs + backends.H5NetCDFStore.dimensions + backends.H5NetCDFStore.ds + backends.H5NetCDFStore.variables + + backends.PydapDataStore.close + backends.PydapDataStore.get + backends.PydapDataStore.get_attrs + backends.PydapDataStore.get_dimensions + backends.PydapDataStore.get_encoding + backends.PydapDataStore.get_variables + backends.PydapDataStore.items + backends.PydapDataStore.keys + backends.PydapDataStore.load + backends.PydapDataStore.open + backends.PydapDataStore.open_store_variable + backends.PydapDataStore.values + backends.PydapDataStore.attrs + backends.PydapDataStore.dimensions + backends.PydapDataStore.variables + + backends.ScipyDataStore.close + backends.ScipyDataStore.encode + backends.ScipyDataStore.encode_attribute + backends.ScipyDataStore.encode_variable + backends.ScipyDataStore.get + backends.ScipyDataStore.get_attrs + backends.ScipyDataStore.get_dimensions + backends.ScipyDataStore.get_encoding + backends.ScipyDataStore.get_variables + backends.ScipyDataStore.items + backends.ScipyDataStore.keys + backends.ScipyDataStore.load + backends.ScipyDataStore.open_store_variable + backends.ScipyDataStore.prepare_variable + backends.ScipyDataStore.set_attribute + backends.ScipyDataStore.set_attributes + backends.ScipyDataStore.set_dimension + backends.ScipyDataStore.set_dimensions + backends.ScipyDataStore.set_variable + backends.ScipyDataStore.set_variables + backends.ScipyDataStore.store + backends.ScipyDataStore.store_dataset + backends.ScipyDataStore.sync + backends.ScipyDataStore.values + backends.ScipyDataStore.attrs + backends.ScipyDataStore.dimensions + backends.ScipyDataStore.ds + backends.ScipyDataStore.variables + + backends.FileManager.acquire + backends.FileManager.acquire_context + backends.FileManager.close + + backends.CachingFileManager.acquire + backends.CachingFileManager.acquire_context + backends.CachingFileManager.close + + backends.DummyFileManager.acquire + backends.DummyFileManager.acquire_context + backends.DummyFileManager.close diff --git a/doc/api.rst b/doc/api.rst index a1fae3deb03..4492d882355 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -31,6 +31,8 @@ Top-level functions ones_like dot map_blocks + show_versions + set_options Dataset ======= @@ -74,7 +76,9 @@ and values given by ``DataArray`` objects. Dataset.__setitem__ Dataset.__delitem__ Dataset.update + Dataset.get Dataset.items + Dataset.keys Dataset.values Dataset contents @@ -410,7 +414,7 @@ Universal functions for the ``xarray.ufuncs`` module, which should not be used for new code unless compatibility with versions of NumPy prior to v1.13 is required. -This functions are copied from NumPy, but extended to work on NumPy arrays, +These functions are copied from NumPy, but extended to work on NumPy arrays, dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs`` module: @@ -537,6 +541,15 @@ DataArray methods DataArray.unify_chunks DataArray.map_blocks +Coordinates objects +=================== + +.. autosummary:: + :toctree: generated/ + + core.coordinates.DataArrayCoordinates + core.coordinates.DatasetCoordinates + GroupBy objects =============== @@ -564,6 +577,16 @@ Rolling objects core.rolling.DatasetRolling.reduce core.rolling_exp.RollingExp +Coarsen objects +=============== + +.. 
autosummary:: + :toctree: generated/ + + core.rolling.DataArrayCoarsen + core.rolling.DatasetCoarsen + + Resample objects ================ @@ -593,6 +616,7 @@ Accessors :toctree: generated/ core.accessor_dt.DatetimeAccessor + core.accessor_dt.TimedeltaAccessor core.accessor_str.StringAccessor Custom Indexes @@ -625,8 +649,36 @@ Plotting plot.imshow plot.line plot.pcolormesh + plot.step plot.FacetGrid +Faceting +-------- +.. autosummary:: + :toctree: generated/ + + plot.FacetGrid + plot.FacetGrid.add_colorbar + plot.FacetGrid.add_legend + plot.FacetGrid.map + plot.FacetGrid.map_dataarray + plot.FacetGrid.map_dataarray_line + plot.FacetGrid.map_dataset + plot.FacetGrid.set_axis_labels + plot.FacetGrid.set_ticks + plot.FacetGrid.set_titles + plot.FacetGrid.set_xlabels + plot.FacetGrid.set_ylabels + +Tutorial +======== + +.. autosummary:: + :toctree: generated/ + + tutorial.open_dataset + tutorial.load_dataset + Testing ======= @@ -663,7 +715,7 @@ Advanced API These backends provide a low-level interface for lazily loading data from external file-formats or protocols, and can be manually invoked to create -arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: +arguments for the ``load_store`` and ``dump_to_store`` Dataset methods: .. autosummary:: :toctree: generated/ @@ -679,6 +731,9 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: Deprecated / Pending Deprecation ================================ +.. autosummary:: + :toctree: generated/ + Dataset.drop DataArray.drop Dataset.apply diff --git a/doc/conf.py b/doc/conf.py index 11abda6bb63..578f9cf550d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -25,7 +25,7 @@ os.environ["PYTHONPATH"] = str(root) sys.path.insert(0, str(root)) -import xarray +import xarray # isort:skip allowed_failures = set() diff --git a/doc/contributing.rst b/doc/contributing.rst index 3cd0b3e8868..eb31db24591 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -231,9 +231,9 @@ About the *xarray* documentation -------------------------------- The documentation is written in **reStructuredText**, which is almost like writing -in plain English, and built using `Sphinx `__. The +in plain English, and built using `Sphinx `__. The Sphinx Documentation has an excellent `introduction to reST -`__. Review the Sphinx docs to perform more +`__. Review the Sphinx docs to perform more complex changes to the documentation as well. Some other important things to know about the docs: @@ -345,33 +345,31 @@ as possible to avoid mass breakages. Code Formatting ~~~~~~~~~~~~~~~ -Xarray uses `Black `_ and -`Flake8 `_ to ensure a consistent code -format throughout the project. ``black`` and ``flake8`` can be installed with +xarray uses several tools to ensure a consistent code format throughout the project: + +- `Black `_ for standardized code formatting +- `Flake8 `_ for general code quality +- `isort `_ for standardized order in imports. + See also `flake8-isort `_. +- `mypy `_ for static type checking on `type hints + `_ + ``pip``:: - pip install black flake8 + pip install black flake8 isort mypy and then run from the root of the Xarray repository:: - black . + isort -rc . + black -t py36 . flake8 + mypy . to auto-format your code. Additionally, many editors have plugins that will apply ``black`` as you edit files. -Other recommended but optional tools for checking code quality (not currently -enforced in CI): - -- `mypy `_ performs static type checking, which can - make it easier to catch bugs. 
Please run ``mypy xarray`` if you annotate any - code with `type hints `_. -- `isort `_ will highlight - incorrectly sorted imports. ``isort -y`` will automatically fix them. See - also `flake8-isort `_. - Optionally, you may wish to setup `pre-commit hooks `_ -to automatically run ``black`` and ``flake8`` when you make a git commit. This +to automatically run all the above tools every time you make a git commit. This can be done by installing ``pre-commit``:: pip install pre-commit @@ -380,25 +378,9 @@ and then running:: pre-commit install -from the root of the Xarray repository. Now ``black`` and ``flake8`` will be run -each time you commit changes. You can skip these checks with +from the root of the xarray repository. You can skip the pre-commit checks with ``git commit --no-verify``. -.. note:: - - If you were working on a branch *prior* to the code being reformatted with black, - you will likely face some merge conflicts. These steps can eliminate many of those - conflicts. Because they have had limited testing, please reach out to the core devs - on your pull request if you face any issues, and we'll help with the merge: - - - Merge the commit on master prior to the ``black`` commit into your branch - ``git merge f172c673``. If you have conflicts here, resolve and commit. - - Apply ``black .`` to your branch and commit ``git commit -am "black"`` - - Apply a patch of other changes we made on that commit: ``curl https://gist.githubusercontent.com/max-sixty/3cceb8472ed4ea806353999ca43aed52/raw/03cbee4e386156bddb61acaa250c0bfc726f596d/xarray%2520black%2520diff | git apply -`` - - Commit (``git commit -am "black2"``) - - Merge master at the ``black`` commit, resolving in favor of 'our' changes: - ``git merge d089df38 -X ours``. You shouldn't have any merge conflicts - - Merge current master ``git merge master``; resolve and commit any conflicts Backwards Compatibility ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/dask.rst b/doc/dask.rst index ed99ffaa896..07b3939af6e 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -130,6 +130,7 @@ will return a ``dask.delayed`` object that can be computed later. A dataset can also be converted to a Dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`. .. ipython:: python + :okwarning: df = ds.to_dask_dataframe() df diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 504d820a234..70e34adabed 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -353,6 +353,8 @@ setting) variables and attributes: This is particularly useful in an exploratory context, because you can tab-complete these variable names with tools like IPython. +.. _dictionary_like_methods: + Dictionary like methods ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/examples.rst b/doc/examples.rst index ce56102cc9d..3067ca824be 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -10,3 +10,10 @@ Examples examples/visualization_gallery examples/ROMS_ocean_model examples/ERA5-GRIB-example + +Using apply_ufunc +------------------ +.. 
toctree:: + :maxdepth: 2 + + examples/apply_ufunc_vectorize_1d diff --git a/doc/examples/_code/weather_data_setup.py b/doc/examples/_code/weather_data_setup.py deleted file mode 100644 index 4e4e2ab176e..00000000000 --- a/doc/examples/_code/weather_data_setup.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -import pandas as pd -import seaborn as sns - -import xarray as xr - -np.random.seed(123) - -times = pd.date_range("2000-01-01", "2001-12-31", name="time") -annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) - -base = 10 + 15 * annual_cycle.reshape(-1, 1) -tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) -tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) - -ds = xr.Dataset( - { - "tmin": (("time", "location"), tmin_values), - "tmax": (("time", "location"), tmax_values), - }, - {"time": times, "location": ["IA", "IN", "IL"]}, -) diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb new file mode 100644 index 00000000000..6d18d48fdb5 --- /dev/null +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Applying unvectorized functions with `apply_ufunc`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to coveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", + "\n", + "We will illustrate this using `np.interp`: \n", + "\n", + " Signature: np.interp(x, xp, fp, left=None, right=None, period=None)\n", + " Docstring:\n", + " One-dimensional linear interpolation.\n", + "\n", + " Returns the one-dimensional piecewise linear interpolant to a function\n", + " with given discrete data points (`xp`, `fp`), evaluated at `x`.\n", + "\n", + "and write an `xr_interp` function with signature\n", + "\n", + " xr_interp(xarray_object, dimension_name, new_coordinate_to_interpolate_to)\n", + "\n", + "### Load data\n", + "\n", + "First lets load an example dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:51.659160Z", + "start_time": "2020-01-15T14:45:50.528742Z" + } + }, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "\n", + "xr.set_options(display_style=\"html\") # fancy HTML repr\n", + "\n", + "air = (\n", + " xr.tutorial.load_dataset(\"air_temperature\")\n", + " .air.sortby(\"lat\") # np.interp needs coordinate in ascending order\n", + " .isel(time=slice(4), lon=slice(3))\n", + ") # choose a small subset for convenience\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function we will apply is `np.interp` which expects 1D numpy arrays. This functionality is already implemented in xarray so we use that capability to make sure we are not making mistakes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:55.431708Z", + "start_time": "2020-01-15T14:45:55.104701Z" + } + }, + "outputs": [], + "source": [ + "newlat = np.linspace(15, 75, 100)\n", + "air.interp(lat=newlat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a function that works with one vector of data along `lat` at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:57.889496Z", + "start_time": "2020-01-15T14:45:57.792269Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = interp1d_np(air.isel(time=0, lon=0), air.lat, newlat)\n", + "expected = air.interp(lat=newlat)\n", + "\n", + "# no errors are raised if values are equal to within floating point precision\n", + "np.testing.assert_allclose(expected.isel(time=0, lon=0).values, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### No errors are raised so our interpolation is working.\n", + "\n", + "This function consumes and returns numpy arrays, which means we need to do a lot of work to convert the result back to an xarray object with meaningful metadata. This is where `apply_ufunc` is very useful.\n", + "\n", + "### `apply_ufunc`\n", + "\n", + " Apply a vectorized function for unlabeled arrays on xarray objects.\n", + "\n", + " The function will be mapped over the data variable(s) of the input arguments using \n", + " xarray’s standard rules for labeled computation, including alignment, broadcasting, \n", + " looping over GroupBy/Dataset variables, and merging of coordinates.\n", + " \n", + "`apply_ufunc` has many capabilities but for simplicity this example will focus on the common task of vectorizing 1D functions over nD xarray objects. We will iteratively build up the right set of arguments to `apply_ufunc` and read through many error messages in doing so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:59.768626Z", + "start_time": "2020-01-15T14:45:59.543808Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`apply_ufunc` needs to know a lot of information about what our function does so that it can reconstruct the outputs. In this case, the size of dimension lat has changed and we need to explicitly specify that this will happen. xarray helpfully tells us that we need to specify the kwarg `exclude_dims`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `exclude_dims`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "```\n", + "exclude_dims : set, optional\n", + " Core dimensions on the inputs to exclude from alignment and\n", + " broadcasting entirely. Any input coordinates along these dimensions\n", + " will be dropped. Each excluded dimension must also appear in\n", + " ``input_core_dims`` for at least one argument. 
Only dimensions listed\n", + " here are allowed to change size between input and output objects.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:02.187012Z", + "start_time": "2020-01-15T14:46:02.105563Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Core dimensions\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Core dimensions are central to using `apply_ufunc`. In our case, our function expects to receive a 1D vector along `lat` — this is the dimension that is \"core\" to the function's functionality. Multiple core dimensions are possible. `apply_ufunc` needs to know which dimensions of each variable are core dimensions.\n", + "\n", + " input_core_dims : Sequence[Sequence], optional\n", + " List of the same length as ``args`` giving the list of core dimensions\n", + " on each input argument that should not be broadcast. By default, we\n", + " assume there are no core dimensions on any input arguments.\n", + "\n", + " For example, ``input_core_dims=[[], ['time']]`` indicates that all\n", + " dimensions on the first argument and all dimensions other than 'time'\n", + " on the second argument should be broadcast.\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_.\n", + " \n", + " output_core_dims : List[tuple], optional\n", + " List of the same length as the number of output arguments from\n", + " ``func``, giving the list of core dimensions on each output that were\n", + " not broadcast on the inputs. By default, we assume that ``func``\n", + " outputs exactly one array, with axes corresponding to each broadcast\n", + " dimension.\n", + "\n", + " Core dimensions are assumed to appear as the last dimensions of each\n", + " output in the provided order.\n", + " \n", + "Next we specify `\"lat\"` as `input_core_dims` on both `air` and `air.lat`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:05.031672Z", + "start_time": "2020-01-15T14:46:04.947588Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "xarray is telling us that it expected to receive back a numpy array with 0 dimensions but instead received an array with 1 dimension corresponding to `newlat`. 
We can fix this by specifying `output_core_dims`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:09.325218Z", + "start_time": "2020-01-15T14:46:09.303020Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we get some output! Let's check that this is right\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:11.295440Z", + "start_time": "2020-01-15T14:46:11.226553Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "No errors are raised so it is right!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorization with `np.vectorize`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", + "Let's try passing the whole dataset. We add a `print` statement so we can see what our function receives." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:13.808646Z", + "start_time": "2020-01-15T14:46:13.680098Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(\n", + " lon=slice(3), time=slice(4)\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a hard-to-interpret error but our `print` call helpfully printed the shapes of the input data: \n", + "\n", + " data: (10, 53, 25) | x: (25,) | xi: (100,)\n", + "\n", + "We see that `air` has been passed as a 3D numpy array which is not what `np.interp` expects. 
Instead we want loop over all combinations of `lon` and `time`; and apply our function to each corresponding vector of data along `lat`.\n", + "`apply_ufunc` makes this easy by specifying `vectorize=True`:\n", + "\n", + " vectorize : bool, optional\n", + " If True, then assume ``func`` only takes arrays defined over core\n", + " dimensions as input and vectorize it automatically with\n", + " :py:func:`numpy.vectorize`. This option exists for convenience, but is\n", + " almost always slower than supplying a pre-vectorized function.\n", + " Using this option requires NumPy version 1.12 or newer.\n", + " \n", + "Also see the documentation for `np.vectorize`: https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html. Most importantly\n", + "\n", + " The vectorize function is provided primarily for convenience, not for performance. \n", + " The implementation is essentially a for loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:26.633233Z", + "start_time": "2020-01-15T14:46:26.515209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This unfortunately is another cryptic error from numpy. \n", + "\n", + "Notice that `newlat` is not an xarray object. Let's add a dimension name `new_lat` and modify the call. Note this cannot be `lat` because xarray expects dimensions to be the same size (or broadcastable) among all inputs. `output_core_dims` needs to be modified appropriately. We'll manually rename `new_lat` back to `lat` for easy checking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:30.026663Z", + "start_time": "2020-01-15T14:46:29.893267Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped = interped.rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims), interped # order of dims is different\n", + ")\n", + "interped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the printed input shapes are all 1D and correspond to one vector along the `lat` dimension.\n", + "\n", + "The result is now an xarray object with coordinate values copied over from `data`. This is why `apply_ufunc` is so convenient; it takes care of a lot of boilerplate necessary to apply functions that consume and produce numpy arrays to xarray objects.\n", + "\n", + "One final point: `lat` is now the *last* dimension in `interped`. This is a \"property\" of core dimensions: they are moved to the end before being sent to `interp1d_np`, as was noted in the docstring for `input_core_dims`:\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Parallelization with dask\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far our function can only handle numpy arrays. A real benefit of `apply_ufunc` is the ability to easily parallelize over dask chunks _when needed_. \n", + "\n", + "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise` or `map_blocks`. `apply_ufunc` wraps `blockwise`, and asking it to map the function over chunks is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility, we need to provide dask with some extra information: \n", + " 1. `output_dtypes`: dtypes of all returned objects, and \n", + " 2. `output_sizes`: lengths of any new dimensions. \n", + " \n", + "Here we only need to specify `output_dtypes`, since `apply_ufunc` can infer the size of the new dimension `new_lat` from the argument corresponding to the third element in `input_core_dims`, so `output_sizes` is not required. I choose the chunk sizes to illustrate that `np.vectorize` is still applied so that our function receives 1D vectors even though the blocks are 3D." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:42.469341Z", + "start_time": "2020-01-15T14:48:42.344209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1d_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + " dask=\"parallelized\",\n", + " output_dtypes=[air.dtype], # one per output\n", + ").rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Yay! Our function is receiving 1D vectors, so we've successfully parallelized applying a 1D function over a block. If you have a distributed dashboard up, you should see computes happening as equality is checked.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### High performance vectorization: gufuncs, numba & guvectorize\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`np.vectorize` is a very convenient function but is unfortunately slow. It is only marginally faster than writing a for loop in Python. A common way to get around this is to write a base interpolation function that can handle nD arrays in a compiled language like Fortran and then pass that to `apply_ufunc`.\n", + "\n", + "Another option is to use the numba package, which provides a very convenient `guvectorize` decorator: https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator\n", + "\n", + "Any decorated function gets compiled and will loop over any non-core dimension in parallel when necessary. We need to specify some extra information:\n", + "\n", + " 1. Our function cannot return a variable any more. Instead it must receive a variable (the last argument) whose contents the function will modify. So we change from `def interp1d_np(data, x, xi)` to `def interp1d_np_gufunc(data, x, xi, out)`. Our computed results must be assigned to `out`. All values of `out` must be assigned explicitly.\n", + " \n", + " 2. `guvectorize` needs to know the dtypes of the input and output. This is specified in string form as the first argument. Each element corresponds to one argument of the function. In this case, we specify `float64` for all inputs and outputs: `\"(float64[:], float64[:], float64[:], float64[:])\"`, corresponding to `data, x, xi, out`.\n", + " \n", + " 3. Now we need to tell numba the sizes of the dimensions the function takes as inputs and returns as outputs, i.e. the core dimensions. This is done in symbolic form, i.e. `data` and `x` are vectors of the same length, say `n`; `xi` and the output `out` have a different length, say `m`. 
So the second argument is (again as a string)\n", + " `\"(n), (n), (m) -> (m)\"`, corresponding again to `data, x, xi, out`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:45.267633Z", + "start_time": "2020-01-15T14:48:44.943939Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\"(float64[:], float64[:], float64[:], float64[:])\", \"(n), (n), (m) -> (m)\")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " # numba doesn't seem to support f-strings,\n", + " # so print the shapes the old way\n", + " print(\n", + " \"data: \" + str(data.shape) + \" | x: \" + str(x.shape) + \" | xi: \" + str(xi.shape)\n", + " )\n", + " out[:] = np.interp(xi, x, data)\n", + " # gufuncs don't return data;\n", + " # instead you assign to the last arg\n", + " # return np.interp(xi, x, data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The warnings are about object-mode compilation relating to the `print` statement. This means we don't get much speedup: https://numba.pydata.org/numba-doc/latest/user/performance-tips.html#no-python-mode-vs-object-mode. We'll keep the `print` statement temporarily to make sure that `guvectorize` acts like we want it to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:54.755405Z", + "start_time": "2020-01-15T14:48:54.634724Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1d_np_gufunc'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[air.dtype], # one per output\n", + ").rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Yay! Our function is receiving 1D vectors and is working automatically with dask arrays. 
Finally, let's comment out the `print` line and wrap everything up in a nice reusable function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:49:28.667528Z", + "start_time": "2020-01-15T14:49:28.103914Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\n", + " \"(float64[:], float64[:], float64[:], float64[:])\",\n", + " \"(n), (n), (m) -> (m)\",\n", + " nopython=True,\n", + ")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " out[:] = np.interp(xi, x, data)\n", + "\n", + "\n", + "def xr_interp(data, dim, newdim):\n", + "\n", + " interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " data, # now arguments in the order expected by 'interp1d_np_gufunc'\n", + " data[dim], # as above\n", + " newdim, # as above\n", + " input_core_dims=[[dim], [dim], [\"__newdim__\"]], # list with one entry per arg\n", + " output_core_dims=[[\"__newdim__\"]], # returned data has one dimension\n", + " exclude_dims=set((dim,)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[data.dtype], # one per output; could also be float or np.dtype(\"float64\")\n", + " ).rename({\"__newdim__\": dim})\n", + " interped[dim] = newdim # need to add this manually\n", + "\n", + " return interped\n", + "\n", + "\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims), # `interped` here is the result from the previous cell\n", + " xr_interp(air.chunk({\"time\": 2, \"lon\": 2}), \"lat\", newlat),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This technique is generalizable to any 1D function.\n", 
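+ "\n", + "As a hedged illustration of that claim (the names `demean_gufunc` and `xr_demean` are ours, not from this notebook), the same gufunc-plus-`apply_ufunc` pattern works for, say, removing the mean along a dimension:\n", + "\n", + "```python\n", + "import xarray as xr\n", + "from numba import guvectorize\n", + "\n", + "\n", + "@guvectorize(\"(float64[:], float64[:])\", \"(n) -> (n)\", nopython=True)\n", + "def demean_gufunc(data, out):\n", + "    out[:] = data - data.mean()  # subtract the mean along the core dimension\n", + "\n", + "\n", + "def xr_demean(da, dim):\n", + "    # note: with dask='parallelized', `dim` must lie in a single chunk\n", + "    return xr.apply_ufunc(\n", + "        demean_gufunc,\n", + "        da,\n", + "        input_core_dims=[[dim]],\n", + "        output_core_dims=[[dim]],  # same core dim in and out: size unchanged\n", + "        dask=\"parallelized\",\n", + "        output_dtypes=[da.dtype],\n", + "    )\n", + "```"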
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "nbsphinx": { + "allow_errors": true + }, + "org": null, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": false, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/doc/examples/weather-data.ipynb b/doc/examples/weather-data.ipynb new file mode 100644 index 00000000000..f582453aacf --- /dev/null +++ b/doc/examples/weather-data.ipynb @@ -0,0 +1,374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Toy weather data\n", + "\n", + "Here is an example of how to easily manipulate a toy weather dataset using\n", + "xarray and other recommended Python libraries:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:43:36.127628Z", + "start_time": "2020-01-27T15:43:36.081733Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "\n", + "import xarray as xr\n", + "\n", + "np.random.seed(123)\n", + "\n", + "xr.set_options(display_style=\"html\")\n", + "\n", + "times = pd.date_range(\"2000-01-01\", \"2001-12-31\", name=\"time\")\n", + "annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))\n", + "\n", + "base = 10 + 15 * annual_cycle.reshape(-1, 1)\n", + "tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3)\n", + "tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3)\n", + "\n", + "ds = xr.Dataset(\n", + " {\n", + " \"tmin\": ((\"time\", \"location\"), tmin_values),\n", + " \"tmax\": ((\"time\", \"location\"), tmax_values),\n", + " },\n", + " {\"time\": times, \"location\": [\"IA\", \"IN\", \"IL\"]},\n", + ")\n", + "\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examine a dataset with pandas and seaborn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Convert to a pandas DataFrame" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:14.160297Z", + "start_time": "2020-01-27T15:47:14.126738Z" + } + }, + "outputs": [], + "source": [ + "df = ds.to_dataframe()\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:32.682065Z", + "start_time": "2020-01-27T15:47:32.652629Z" + } + }, + "outputs": [], + "source": [ + "df.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize using pandas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:34.617042Z", + "start_time": "2020-01-27T15:47:34.282605Z" + } + }, + "outputs": [], + "source": [ + "ds.mean(dim=\"location\").to_dataframe().plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize using seaborn" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:47:37.643175Z", + "start_time": "2020-01-27T15:47:37.202479Z" + } + }, + "outputs": [], + "source": [ + "sns.pairplot(df.reset_index(), vars=ds.data_vars)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Probability of freeze by calendar month" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:11.241224Z", + "start_time": "2020-01-27T15:48:11.211156Z" + } + }, + "outputs": [], + "source": [ + "freeze = (ds[\"tmin\"] <= 0).groupby(\"time.month\").mean(\"time\")\n", + "freeze" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:13.131247Z", + "start_time": "2020-01-27T15:48:12.924985Z" + } + }, + "outputs": [], + "source": [ + "freeze.to_pandas().plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Monthly averaging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:48:08.498259Z", + "start_time": "2020-01-27T15:48:08.210890Z" + } + }, + "outputs": [], + "source": [ + "monthly_avg = ds.resample(time=\"1MS\").mean()\n", + "monthly_avg.sel(location=\"IA\").to_dataframe().plot(style=\"s-\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that ``MS`` here refers to Month-Start; ``M`` labels Month-End (the last day of the month)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calculate monthly anomalies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In climatology, \"anomalies\" refer to the difference between observations and\n", + "typical weather for a particular season. Unlike observations, anomalies should\n", + "not show any seasonal cycle." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:49:34.855086Z", + "start_time": "2020-01-27T15:49:34.406439Z" + } + }, + "outputs": [], + "source": [ + "climatology = ds.groupby(\"time.month\").mean(\"time\")\n", + "anomalies = ds.groupby(\"time.month\") - climatology\n", + "anomalies.mean(\"location\").to_dataframe()[[\"tmin\", \"tmax\"]].plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calculate standardized monthly anomalies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can create standardized anomalies where the difference between the\n", + "observations and the climatological monthly mean is\n", + "divided by the climatological standard deviation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:50:09.144586Z", + "start_time": "2020-01-27T15:50:08.734682Z" + } + }, + "outputs": [], + "source": [ + "climatology_mean = ds.groupby(\"time.month\").mean(\"time\")\n", + "climatology_std = ds.groupby(\"time.month\").std(\"time\")\n", + "stand_anomalies = xr.apply_ufunc(\n", + " lambda x, m, s: (x - m) / s,\n", + " ds.groupby(\"time.month\"),\n", + " climatology_mean,\n", + " climatology_std,\n", + ")\n", + "\n", + "stand_anomalies.mean(\"location\").to_dataframe()[[\"tmin\", \"tmax\"]].plot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fill missing values with climatology" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:50:46.192491Z", + "start_time": "2020-01-27T15:50:46.174554Z" + } + }, + "source": [ + "The ``fillna`` method on grouped objects lets you easily fill missing values by group:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:51:40.279299Z", + "start_time": "2020-01-27T15:51:40.220342Z" + } + }, + "outputs": [], + "source": [ + "# throw away the first half of every month\n", + "some_missing = ds.tmin.sel(time=ds[\"time.day\"] > 15).reindex_like(ds)\n", + "filled = some_missing.groupby(\"time.month\").fillna(climatology.tmin)\n", + "both = xr.Dataset({\"some_missing\": some_missing, \"filled\": filled})\n", + "both" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:52:11.815769Z", + "start_time": "2020-01-27T15:52:11.770825Z" + } + }, + "outputs": [], + "source": [ + "df = both.sel(time=\"2000\").mean(\"location\").reset_coords(drop=True).to_dataframe()\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-27T15:52:14.867866Z", + "start_time": "2020-01-27T15:52:14.449684Z" + } + }, + "outputs": [], + "source": [ + "df[[\"filled\", \"some_missing\"]].plot()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": true, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/examples/weather-data.rst b/doc/examples/weather-data.rst deleted file mode 100644 index 5a019e637c4..00000000000 --- a/doc/examples/weather-data.rst +++ /dev/null @@ -1,138 +0,0 @@ -.. _toy weather data: - -Toy weather data -================ - -Here is an example of how to easily manipulate a toy weather dataset using -xarray and other recommended Python libraries: - -.. contents:: - :local: - :depth: 1 - -Shared setup: - -.. literalinclude:: _code/weather_data_setup.py - -.. 
ipython:: python - :suppress: - - fpath = "examples/_code/weather_data_setup.py" - with open(fpath) as f: - code = compile(f.read(), fpath, 'exec') - exec(code) - - -Examine a dataset with pandas_ and seaborn_ -------------------------------------------- - -.. _pandas: http://pandas.pydata.org -.. _seaborn: http://stanford.edu/~mwaskom/software/seaborn - -.. ipython:: python - - ds - - df = ds.to_dataframe() - - df.head() - - df.describe() - - @savefig examples_tmin_tmax_plot.png - ds.mean(dim='location').to_dataframe().plot() - - -.. ipython:: python - - @savefig examples_pairplot.png - sns.pairplot(df.reset_index(), vars=ds.data_vars) - -.. _average by month: - -Probability of freeze by calendar month ---------------------------------------- - -.. ipython:: python - - freeze = (ds['tmin'] <= 0).groupby('time.month').mean('time') - freeze - - @savefig examples_freeze_prob.png - freeze.to_pandas().plot() - -.. _monthly average: - -Monthly averaging ------------------ - -.. ipython:: python - - monthly_avg = ds.resample(time='1MS').mean() - - @savefig examples_tmin_tmax_plot_mean.png - monthly_avg.sel(location='IA').to_dataframe().plot(style='s-') - -Note that ``MS`` here refers to Month-Start; ``M`` labels Month-End (the last -day of the month). - -.. _monthly anomalies: - -Calculate monthly anomalies ---------------------------- - -In climatology, "anomalies" refer to the difference between observations and -typical weather for a particular season. Unlike observations, anomalies should -not show any seasonal cycle. - -.. ipython:: python - - climatology = ds.groupby('time.month').mean('time') - anomalies = ds.groupby('time.month') - climatology - - @savefig examples_anomalies_plot.png - anomalies.mean('location').to_dataframe()[['tmin', 'tmax']].plot() - -.. _standardized monthly anomalies: - -Calculate standardized monthly anomalies ----------------------------------------- - -You can create standardized anomalies where the difference between the -observations and the climatological monthly mean is -divided by the climatological standard deviation. - -.. ipython:: python - - climatology_mean = ds.groupby('time.month').mean('time') - climatology_std = ds.groupby('time.month').std('time') - stand_anomalies = xr.apply_ufunc( - lambda x, m, s: (x - m) / s, - ds.groupby('time.month'), - climatology_mean, climatology_std) - - @savefig examples_standardized_anomalies_plot.png - stand_anomalies.mean('location').to_dataframe()[['tmin', 'tmax']].plot() - -.. _fill with climatology: - -Fill missing values with climatology ------------------------------------- - -The :py:func:`~xarray.Dataset.fillna` method on grouped objects lets you easily -fill missing values by group: - -.. ipython:: python - :okwarning: - - # throw away the first half of every month - some_missing = ds.tmin.sel(time=ds['time.day'] > 15).reindex_like(ds) - filled = some_missing.groupby('time.month').fillna(climatology.tmin) - - both = xr.Dataset({'some_missing': some_missing, 'filled': filled}) - both - - df = both.sel(time='2000').mean('location').reset_coords(drop=True).to_dataframe() - - @savefig examples_filled.png - df[['filled', 'some_missing']].plot() diff --git a/doc/groupby.rst b/doc/groupby.rst index f5943703765..927e192eb6c 100644 --- a/doc/groupby.rst +++ b/doc/groupby.rst @@ -94,7 +94,7 @@ Apply ~~~~~ To apply a function to each group, you can use the flexible -:py:meth:`~xarray.DatasetGroupBy.map` method. The resulting objects are automatically +:py:meth:`~xarray.core.groupby.DatasetGroupBy.map` method. 
The resulting objects are automatically concatenated back together along the group axis: .. ipython:: python @@ -104,8 +104,8 @@ concatenated back together along the group axis: arr.groupby('letters').map(standardize) -GroupBy objects also have a :py:meth:`~xarray.DatasetGroupBy.reduce` method and -methods like :py:meth:`~xarray.DatasetGroupBy.mean` as shortcuts for applying an +GroupBy objects also have a :py:meth:`~xarray.core.groupby.DatasetGroupBy.reduce` method and +methods like :py:meth:`~xarray.core.groupby.DatasetGroupBy.mean` as shortcuts for applying an aggregation function: .. ipython:: python diff --git a/doc/howdoi.rst b/doc/howdoi.rst index 91644ba2718..84c0c786027 100644 --- a/doc/howdoi.rst +++ b/doc/howdoi.rst @@ -11,6 +11,8 @@ How do I ... * - How do I... - Solution + * - add a DataArray to my dataset as a new variable + - ``my_dataset[varname] = my_dataArray`` or :py:meth:`Dataset.assign` (see also :ref:`dictionary_like_methods`) * - add variables from other datasets to my dataset - :py:meth:`Dataset.merge` * - add a new dimension and/or coordinate @@ -22,7 +24,7 @@ How do I ... * - change the order of dimensions - :py:meth:`DataArray.transpose`, :py:meth:`Dataset.transpose` * - remove a variable from my object - - :py:meth:`Dataset.drop`, :py:meth:`DataArray.drop` + - :py:meth:`Dataset.drop_vars`, :py:meth:`DataArray.drop_vars` * - remove dimensions of length 1 or 0 - :py:meth:`DataArray.squeeze`, :py:meth:`Dataset.squeeze` * - remove all variables with a particular dimension @@ -48,7 +50,7 @@ How do I ... * - write xarray objects with complex values to a netCDF file - :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="h5netcdf", invalid_netcdf=True`` * - make xarray objects look like other xarray objects - - :py:func:`~xarray.ones_like`, :py:func:`~xarray.zeros_like`, :py:func:`~xarray.full_like`, :py:meth:`Dataset.reindex_like`, :py:meth:`Dataset.interpolate_like`, :py:meth:`Dataset.broadcast_like`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interpolate_like`, :py:meth:`DataArray.broadcast_like` + - :py:func:`~xarray.ones_like`, :py:func:`~xarray.zeros_like`, :py:func:`~xarray.full_like`, :py:meth:`Dataset.reindex_like`, :py:meth:`Dataset.interp_like`, :py:meth:`Dataset.broadcast_like`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interp_like`, :py:meth:`DataArray.broadcast_like` * - replace NaNs with other values - :py:meth:`Dataset.fillna`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill`, :py:meth:`Dataset.interpolate_na`, :py:meth:`DataArray.fillna`, :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`DataArray.interpolate_na` * - extract the year, month, day or similar from a DataArray of time values @@ -57,3 +59,4 @@ How do I ... - ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more. * - make a mask that is ``True`` where an object contains any of the values in an array - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin` + diff --git a/doc/indexing.rst b/doc/indexing.rst index e8482ac66b3..cfbb84a8343 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -132,7 +132,7 @@ use them explicitly to slice data. There are two ways to do this: The arguments to these methods can be any objects that could index the array along the dimension given by the keyword, e.g., labels for an individual value, -Python :py:func:`slice` objects or 1-dimensional arrays. +Python :py:class:`slice` objects or 1-dimensional arrays. .. 
note:: diff --git a/doc/installing.rst b/doc/installing.rst index 219cf109efe..dfc2841a956 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -7,8 +7,9 @@ Required dependencies --------------------- - Python (3.6 or later) -- `numpy `__ (1.14 or later) -- `pandas `__ (0.24 or later) +- setuptools +- `numpy `__ (1.15 or later) +- `pandas `__ (0.25 or later) Optional dependencies --------------------- @@ -58,7 +59,7 @@ For plotting - `matplotlib `__: required for :ref:`plotting` - `cartopy `__: recommended for :ref:`plot-maps` -- `seaborn `__: for better +- `seaborn `__: for better color palettes - `nc-time-axis `__: for plotting cftime.datetime objects diff --git a/doc/interpolation.rst b/doc/interpolation.rst index 7c750506cf3..63e9a7cd35e 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -48,7 +48,7 @@ array-like, which gives the interpolated result as an array. # interpolation da.interp(time=[2.5, 3.5]) -To interpolate data with a :py:func:`numpy.datetime64` coordinate you can pass a string. +To interpolate data with a :py:doc:`numpy.datetime64 ` coordinate you can pass a string. .. ipython:: python @@ -128,7 +128,7 @@ It is now possible to safely compute the difference ``other - interpolated``. Interpolation methods --------------------- -We use :py:func:`scipy.interpolate.interp1d` for 1-dimensional interpolation and +We use :py:class:`scipy.interpolate.interp1d` for 1-dimensional interpolation and :py:func:`scipy.interpolate.interpn` for multi-dimensional interpolation. The interpolation method can be specified by the optional ``method`` argument. diff --git a/doc/io.rst b/doc/io.rst index 8f8a776f73a..e910943236f 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -1,3 +1,4 @@ +.. currentmodule:: xarray .. _io: Reading and writing files ========================= @@ -23,8 +24,8 @@ netCDF The recommended way to store xarray data structures is `netCDF`__, which is a binary file format for self-described datasets that originated in the geosciences. xarray is based on the netCDF data model, so netCDF files -on disk directly correspond to :py:class:`~xarray.Dataset` objects (more accurately, -a group in a netCDF file directly corresponds to a to :py:class:`~xarray.Dataset` object. +on disk directly correspond to :py:class:`Dataset` objects (more accurately, +a group in a netCDF file directly corresponds to a :py:class:`Dataset` object. See :ref:`io.netcdf_groups` for more.) NetCDF is supported on almost all platforms, and parsers exist @@ -47,7 +48,7 @@ read/write netCDF V4 files and use the compression options described below). __ https://github.com/Unidata/netcdf4-python We can save a Dataset to disk using the -:py:meth:`~Dataset.to_netcdf` method: +:py:meth:`Dataset.to_netcdf` method: .. ipython:: python @@ -65,13 +66,13 @@ the ``format`` and ``engine`` arguments. .. tip:: Using the `h5netcdf `_ package - by passing ``engine='h5netcdf'`` to :py:meth:`~xarray.open_dataset` can + by passing ``engine='h5netcdf'`` to :py:meth:`open_dataset` can sometimes be quicker than the default ``engine='netcdf4'`` that uses the `netCDF4 `_ package. We can load netCDF files to create a new Dataset using -:py:func:`~xarray.open_dataset`: +:py:func:`open_dataset`: .. ipython:: python @@ -79,9 +80,9 @@ We can load netCDF files to create a new Dataset using ds_disk Similarly, a DataArray can be saved to disk using the -:py:attr:`DataArray.to_netcdf ` method, and loaded -from disk using the :py:func:`~xarray.open_dataarray` function. 
As netCDF files -correspond to :py:class:`~xarray.Dataset` objects, these functions internally +:py:meth:`DataArray.to_netcdf` method, and loaded +from disk using the :py:func:`open_dataarray` function. As netCDF files +correspond to :py:class:`Dataset` objects, these functions internally convert the ``DataArray`` to a ``Dataset`` before saving, and then convert back when loading, ensuring that the ``DataArray`` that is loaded is always exactly the same as the one that was saved. @@ -108,9 +109,9 @@ is modified: the original file on disk is never touched. xarray's lazy loading of remote or on-disk datasets is often but not always desirable. Before performing computationally intense operations, it is often a good idea to load a Dataset (or DataArray) entirely into memory by - invoking the :py:meth:`~xarray.Dataset.load` method. + invoking the :py:meth:`Dataset.load` method. -Datasets have a :py:meth:`~xarray.Dataset.close` method to close the associated +Datasets have a :py:meth:`Dataset.close` method to close the associated netCDF file. However, it's often cleaner to use a ``with`` statement: .. ipython:: python @@ -135,17 +136,17 @@ to the original netCDF file, regardless if they exist in the original dataset. Groups ~~~~~~ -NetCDF groups are not supported as part of the :py:class:`~xarray.Dataset` data model. +NetCDF groups are not supported as part of the :py:class:`Dataset` data model. Instead, groups can be loaded individually as Dataset objects. To do so, pass a ``group`` keyword argument to the -:py:func:`~xarray.open_dataset` function. The group can be specified as a path-like +:py:func:`open_dataset` function. The group can be specified as a path-like string, e.g., to access subgroup ``'bar'`` within group ``'foo'`` pass ``'/foo/bar'`` as the ``group`` argument. In a similar way, the ``group`` keyword argument can be given to the -:py:meth:`~xarray.Dataset.to_netcdf` method to write to a group +:py:meth:`Dataset.to_netcdf` method to write to a group in a netCDF file. When writing multiple groups in one file, pass ``mode='a'`` to -:py:meth:`~xarray.Dataset.to_netcdf` to ensure that each call does not delete the file. +:py:meth:`Dataset.to_netcdf` to ensure that each call does not delete the file. .. _io.encoding: @@ -155,7 +156,7 @@ Reading encoded data NetCDF files follow some conventions for encoding datetime arrays (as numbers with a "units" attribute) and for packing and unpacking data (as described by the "scale_factor" and "add_offset" attributes). If the argument -``decode_cf=True`` (default) is given to :py:func:`~xarray.open_dataset`, xarray will attempt +``decode_cf=True`` (default) is given to :py:func:`open_dataset`, xarray will attempt to automatically decode the values in the netCDF objects according to `CF conventions`_. Sometimes this will fail, for example, if a variable has an invalid "units" or "calendar" attribute. For these cases, you can @@ -164,8 +165,8 @@ turn this decoding off manually. .. _CF conventions: http://cfconventions.org/ You can view this encoding information (among others) in the -:py:attr:`DataArray.encoding ` and -:py:attr:`DataArray.encoding ` attributes: +:py:attr:`DataArray.encoding` and +:py:attr:`Dataset.encoding` attributes: .. ipython:: :verbatim: @@ -206,13 +207,13 @@ Reading multi-file datasets NetCDF files are often encountered in collections, e.g., with different files corresponding to different model runs or one file per timestamp. 
xarray can straightforwardly combine such files into a single Dataset by making use of -:py:func:`~xarray.concat`, :py:func:`~xarray.merge`, :py:func:`~xarray.combine_nested` and -:py:func:`~xarray.combine_by_coords`. For details on the difference between these +:py:func:`concat`, :py:func:`merge`, :py:func:`combine_nested` and +:py:func:`combine_by_coords`. For details on the difference between these functions see :ref:`combining data`. Xarray includes support for manipulating datasets that don't fit into memory with dask_. If you have dask installed, you can open multiple files -simultaneously in parallel using :py:func:`~xarray.open_mfdataset`:: +simultaneously in parallel using :py:func:`open_mfdataset`:: xr.open_mfdataset('my/files/*.nc', parallel=True) @@ -221,7 +222,7 @@ single xarray dataset. It is the recommended way to open multiple files with xarray. For more details on parallel reading, see :ref:`combining.multi`, :ref:`dask.io` and a `blog post`_ by Stephan Hoyer. -:py:func:`~xarray.open_mfdataset` takes many kwargs that allow you to +:py:func:`open_mfdataset` takes many kwargs that allow you to control its behaviour (e.g. ``parallel``, ``combine``, ``compat``, ``join``, ``concat_dim``). See its docstring for more details. @@ -246,14 +247,14 @@ See its docstring for more details. .. _dask: http://dask.pydata.org .. _blog post: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ -Sometimes multi-file datasets are not conveniently organized for easy use of :py:func:`~xarray.open_mfdataset`. +Sometimes multi-file datasets are not conveniently organized for easy use of :py:func:`open_mfdataset`. One can use the ``preprocess`` argument to provide a function that takes a dataset and returns a modified Dataset. -:py:func:`~xarray.open_mfdataset` will call ``preprocess`` on every dataset +:py:func:`open_mfdataset` will call ``preprocess`` on every dataset (corresponding to each file) prior to combining them. -If :py:func:`~xarray.open_mfdataset` does not meet your needs, other approaches are possible. +If :py:func:`open_mfdataset` does not meet your needs, other approaches are possible. The general pattern for parallel reading of multiple files using dask, modifying those datasets and then combining into a single ``Dataset`` is:: @@ -437,17 +438,31 @@ like ``'days'`` for ``timedelta64`` data. ``calendar`` should be one of the cale supported by netCDF4-python: 'standard', 'gregorian', 'proleptic_gregorian', 'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day'. -By default, xarray uses the 'proleptic_gregorian' calendar and units of the smallest time +By default, xarray uses the ``'proleptic_gregorian'`` calendar and units of the smallest time difference between values, with a reference time of the first time value. + +.. _io.coordinates: + +Coordinates +........... + +You can control the ``coordinates`` attribute written to disk by specifying ``DataArray.encoding["coordinates"]``. +If not specified, xarray automatically sets ``DataArray.encoding["coordinates"]`` to a space-delimited list +of names of coordinate variables that share dimensions with the ``DataArray`` being written. +This allows perfect roundtripping of xarray datasets but may not be desirable. +When an xarray ``Dataset`` contains non-dimensional coordinates that do not share dimensions with any of +the variables, these coordinate variable names are saved under a "global" ``"coordinates"`` attribute. +This is not CF-compliant but again facilitates roundtripping of xarray datasets. 
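+
+As a quick, hedged sketch of the behaviour just described (made-up variable
+and file names, not from the library's test suite), one can override the
+attribute explicitly before writing::
+
+    import numpy as np
+    import xarray as xr
+
+    ds = xr.Dataset(
+        {"temp": (("x",), np.arange(3.0))},
+        coords={"x": [10, 20, 30], "station": ("x", ["a", "b", "c"])},
+    )
+    # xarray would write coordinates="station" for temp by default;
+    # setting the encoding makes that choice explicit and editable
+    ds["temp"].encoding["coordinates"] = "station"
+    ds.to_netcdf("coords-sketch.nc")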
+ Invalid netCDF files ~~~~~~~~~~~~~~~~~~~~ The library ``h5netcdf`` allows writing some dtypes (booleans, complex, ...) that aren't allowed in netCDF4 (see -`h5netcdf documentation `_. -This feature is availabe through :py:func:`DataArray.to_netcdf` and -:py:func:`Dataset.to_netcdf` when used with ``engine="h5netcdf"`` +`h5netcdf documentation `_). +This feature is available through :py:meth:`DataArray.to_netcdf` and +:py:meth:`Dataset.to_netcdf` when used with ``engine="h5netcdf"`` and currently raises a warning unless ``invalid_netcdf=True`` is set: .. ipython:: python @@ -480,7 +495,7 @@ The Iris_ tool allows easy reading of common meteorological and climate model fo (including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is installed xarray can convert a ``DataArray`` into a ``Cube`` using -:py:meth:`~xarray.DataArray.to_iris`: +:py:meth:`DataArray.to_iris`: .. ipython:: python @@ -492,7 +507,7 @@ installed xarray can convert a ``DataArray`` into a ``Cube`` using cube Conversely, we can create a new ``DataArray`` object from a ``Cube`` using -:py:meth:`~xarray.DataArray.from_iris`: +:py:meth:`DataArray.from_iris`: .. ipython:: python @@ -594,7 +609,7 @@ over the network until we look at particular values: .. image:: _static/opendap-prism-tmax.png Some servers require authentication before we can access the data. For this -purpose we can explicitly create a :py:class:`~xarray.backends.PydapDataStore` +purpose we can explicitly create a :py:class:`backends.PydapDataStore` and pass in a `Requests`__ session object. For example for HTTP Basic authentication:: @@ -657,8 +672,8 @@ this version of xarray will work in future versions. When pickling an object opened from a NetCDF file, the pickle file will contain a reference to the file on disk. If you want to store the actual - array values, load it into memory first with :py:meth:`~xarray.Dataset.load` - or :py:meth:`~xarray.Dataset.compute`. + array values, load it into memory first with :py:meth:`Dataset.load` + or :py:meth:`Dataset.compute`. .. _dictionary io: Dictionary ---------- We can convert a ``Dataset`` (or a ``DataArray``) to a dict using -:py:meth:`~xarray.Dataset.to_dict`: +:py:meth:`Dataset.to_dict`: .. ipython:: python @@ -674,7 +689,7 @@ We can convert a ``Dataset`` (or a ``DataArray``) to a dict using d We can create a new xarray object from a dict using -:py:meth:`~xarray.Dataset.from_dict`: +:py:meth:`Dataset.from_dict`: .. ipython:: python @@ -709,7 +724,7 @@ Rasterio GeoTIFFs and other gridded raster datasets can be opened using `rasterio`_, if rasterio is installed. Here is an example of how to use -:py:func:`~xarray.open_rasterio` to read one of rasterio's `test files`_: +:py:func:`open_rasterio` to read one of rasterio's `test files`_: .. ipython:: :verbatim: @@ -768,8 +783,7 @@ Xarray's Zarr backend allows xarray to leverage these capabilities. Xarray can't open just any zarr dataset, because xarray requires special metadata (attributes) describing the dataset dimensions and coordinates. At this time, xarray can only open zarr datasets that have been written by -xarray. To write a dataset with zarr, we use the -:py:attr:`Dataset.to_zarr ` method. +xarray. To write a dataset with zarr, we use the :py:attr:`Dataset.to_zarr` method. To write to a local directory, we pass a path to a directory .. 
ipython:: python @@ -816,7 +830,7 @@ can be omitted as it will internally be set to ``'a'``. To store variable length strings use ``dtype=object``. To read back a zarr dataset that has been created this way, we use the -:py:func:`~xarray.open_zarr` method: +:py:func:`open_zarr` method: .. ipython:: python @@ -885,12 +899,12 @@ opening the store. (For more information on this feature, consult the If you have zarr version 2.3 or greater, xarray can write and read stores with consolidated metadata. To write consolidated metadata, pass the ``consolidated=True`` option to the -:py:attr:`Dataset.to_zarr ` method:: +:py:attr:`Dataset.to_zarr` method:: ds.to_zarr('foo.zarr', consolidated=True) To read a consolidated store, pass the ``consolidated=True`` option to -:py:func:`~xarray.open_zarr`:: +:py:func:`open_zarr`:: ds = xr.open_zarr('foo.zarr', consolidated=True) @@ -912,7 +926,7 @@ GRIB format via cfgrib xarray supports reading GRIB files via ECMWF cfgrib_ python driver and ecCodes_ C-library, if they are installed. To open a GRIB file supply ``engine='cfgrib'`` -to :py:func:`~xarray.open_dataset`: +to :py:func:`open_dataset`: .. ipython:: :verbatim: @@ -934,7 +948,7 @@ Formats supported by PyNIO xarray can also read GRIB, HDF4 and other file formats supported by PyNIO_, if PyNIO is installed. To use PyNIO to read such files, supply -``engine='pynio'`` to :py:func:`~xarray.open_dataset`. +``engine='pynio'`` to :py:func:`open_dataset`. We recommend installing PyNIO via conda:: @@ -956,7 +970,7 @@ identify readers heuristically, or format can be specified via a key in `backend_kwargs`. To use PseudoNetCDF to read such files, supply -``engine='pseudonetcdf'`` to :py:func:`~xarray.open_dataset`. +``engine='pseudonetcdf'`` to :py:func:`open_dataset`. Add ``backend_kwargs={'format': ''}`` where `` options are listed on the PseudoNetCDF page. diff --git a/doc/pandas.rst b/doc/pandas.rst index 72abf6609f6..b1660e48dd2 100644 --- a/doc/pandas.rst +++ b/doc/pandas.rst @@ -1,3 +1,4 @@ +.. currentmodule:: xarray .. _pandas: =================== @@ -11,7 +12,7 @@ using the visualization `built in to pandas itself`__ or provided by the pandas aware libraries such as `Seaborn`__. __ http://pandas.pydata.org/pandas-docs/stable/visualization.html -__ http://stanford.edu/~mwaskom/software/seaborn/ +__ http://seaborn.pydata.org/ .. ipython:: python :suppress: @@ -32,9 +33,9 @@ Tabular data is easiest to work with when it meets the criteria for __ http://www.jstatsoft.org/v59/i10/ -In this "tidy data" format, we can represent any :py:class:`~xarray.Dataset` and -:py:class:`~xarray.DataArray` in terms of :py:class:`pandas.DataFrame` and -:py:class:`pandas.Series`, respectively (and vice-versa). The representation +In this "tidy data" format, we can represent any :py:class:`Dataset` and +:py:class:`DataArray` in terms of :py:class:`~pandas.DataFrame` and +:py:class:`~pandas.Series`, respectively (and vice-versa). The representation works by flattening non-coordinates to 1D, and turning the tensor product of coordinate indexes into a :py:class:`pandas.MultiIndex`. @@ -42,7 +43,7 @@ Dataset and DataFrame --------------------- To convert any dataset to a ``DataFrame`` in tidy form, use the -:py:meth:`Dataset.to_dataframe() ` method: +:py:meth:`Dataset.to_dataframe()` method: .. ipython:: python @@ -61,11 +62,11 @@ use ``DataFrame`` methods like :py:meth:`~pandas.DataFrame.reset_index`, :py:meth:`~pandas.DataFrame.stack` and :py:meth:`~pandas.DataFrame.unstack`. 
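
A hedged aside (toy data of our own, not from these docs) showing the
roundtrip just described::

    import xarray as xr

    ds = xr.Dataset(
        {"v": (("x", "y"), [[1.0, 2.0], [3.0, 4.0]])},
        coords={"x": [0, 1], "y": ["a", "b"]},
    )
    df = ds.to_dataframe()  # rows indexed by a MultiIndex over (x, y)
    roundtrip = xr.Dataset.from_dataframe(df)  # dimensions (x, y) restored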
For datasets containing dask arrays where the data should be lazily loaded, see the -:py:meth:`Dataset.to_dask_dataframe() ` method. +:py:meth:`Dataset.to_dask_dataframe()` method. To create a ``Dataset`` from a ``DataFrame``, use the -:py:meth:`~xarray.Dataset.from_dataframe` class method or the equivalent -:py:meth:`pandas.DataFrame.to_xarray ` method: +:py:meth:`Dataset.from_dataframe` class method or the equivalent +:py:meth:`pandas.DataFrame.to_xarray` method: .. ipython:: python @@ -83,7 +84,7 @@ DataArray and Series -------------------- ``DataArray`` objects have a complementary representation in terms of a -:py:class:`pandas.Series`. Using a Series preserves the ``Dataset`` to +:py:class:`~pandas.Series`. Using a Series preserves the ``Dataset`` to ``DataArray`` relationship, because ``DataFrames`` are dict-like containers of ``Series``. The methods are very similar to those for working with DataFrames: @@ -109,7 +110,7 @@ Multi-dimensional data Tidy data is great, but sometimes you want to preserve dimensions instead of automatically stacking them into a ``MultiIndex``. -:py:meth:`DataArray.to_pandas() ` is a shortcut that +:py:meth:`DataArray.to_pandas()` is a shortcut that lets you convert a DataArray directly into a pandas object with the same dimensionality (i.e., a 1D array is converted to a :py:class:`~pandas.Series`, 2D to :py:class:`~pandas.DataFrame` and 3D to ``pandas.Panel``): @@ -122,7 +123,7 @@ dimensionality (i.e., a 1D array is converted to a :py:class:`~pandas.Series`, df To perform the inverse operation of converting any pandas objects into a data -array with the same shape, simply use the :py:class:`~xarray.DataArray` +array with the same shape, simply use the :py:class:`DataArray` constructor: .. ipython:: python @@ -143,7 +144,7 @@ preserve all use of multi-indexes: However, you will need to set dimension names explicitly, either with the ``dims`` argument in the ``DataArray`` constructor or by calling -:py:class:`~xarray.Dataset.rename` on the new object. +:py:meth:`~Dataset.rename` on the new object. .. _panel transition: diff --git a/doc/plotting.rst b/doc/plotting.rst index 270988b99de..ea9816780a7 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -1,3 +1,4 @@ +.. currentmodule:: xarray .. _plotting: Plotting @@ -10,8 +11,8 @@ Labeled data enables expressive computations. These same labels can also be used to easily create informative plots. xarray's plotting capabilities are centered around -:py:class:`xarray.DataArray` objects. -To plot :py:class:`xarray.Dataset` objects +:py:class:`DataArray` objects. +To plot :py:class:`Dataset` objects simply access the relevant DataArrays, i.e. ``dset['var1']``. Dataset specific plotting routines are also available (see :ref:`plot-dataset`). Here we focus mostly on arrays 2d or larger. If your data fits @@ -94,7 +95,7 @@ One Dimension Simple Example ================ -The simplest way to make a plot is to call the :py:func:`xarray.DataArray.plot()` method. +The simplest way to make a plot is to call the :py:func:`DataArray.plot()` method. .. ipython:: python @@ -227,7 +228,7 @@ It is required to explicitly specify either Thus, we could have made the previous plot by specifying ``hue='lat'`` instead of ``x='time'``. If required, the automatic legend can be turned off using ``add_legend=False``. Alternatively, -``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`. 
+``hue`` can be passed directly to :py:func:`xarray.plot.line` as `air.isel(lon=10, lat=[19,21,22]).plot.line(hue='lat')`. ======================== @@ -256,7 +257,7 @@ made using 1D data. The argument ``where`` defines where the steps should be placed, options are ``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy -when plotting data grouped with :py:func:`xarray.Dataset.groupby_bins`. +when plotting data grouped with :py:meth:`Dataset.groupby_bins`. .. ipython:: python @@ -295,7 +296,7 @@ Two Dimensions Simple Example ================ -The default method :py:meth:`xarray.DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional. +The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` when the data is two-dimensional. .. ipython:: python @@ -487,6 +488,7 @@ Faceting here refers to splitting an array along one or two dimensions and plotting each group. xarray's basic plotting is useful for plotting two dimensional arrays. What about three or four dimensional arrays? That's where facets become helpful. +The general approach to plotting here is called “small multiples”, where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship conditioned on one or more other variables is often called a “trellis plot”. Consider the temperature data set. There are 4 observations per day for two years which makes for 2920 values along the time dimension. @@ -572,8 +574,9 @@ Faceted plotting supports other arguments common to xarray 2d plots. FacetGrid Objects =================== -:py:class:`xarray.plot.FacetGrid` is used to control the behavior of the -multiple plots. +The object returned, ``g`` in the above examples, is a :py:class:`~xarray.plot.FacetGrid` object +that links a :py:class:`DataArray` to a matplotlib figure with a particular structure. +This object can be used to control the behavior of the multiple plots. It borrows an API and code from `Seaborn's FacetGrid `_. The structure is contained within the ``axes`` and ``name_dicts`` @@ -609,6 +612,13 @@ they have been plotted. @savefig plot_facet_iterator.png plt.draw() + +:py:class:`~xarray.plot.FacetGrid` objects have methods that let you customize the automatically generated +axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`, +:py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and +:py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information. +Plotting functions can be applied to each subset of the data by calling :py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`. + TODO: add an example of using the ``map`` method to plot dataset variables (e.g., with ``plt.quiver``). diff --git a/doc/related-projects.rst b/doc/related-projects.rst index a8af05f3074..3188751366f 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -25,6 +25,7 @@ Geosciences - `PyGDX `_: Python 3 package for accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom subclass. +- `pyinterp `_: Python 3 package for interpolating geo-referenced data used in the field of geosciences. - `pyXpcm `_: xarray-based Profile Classification Modelling (PCM), mostly for ocean data. 
- `Regionmask `_: plotting and creation of masks of spatial regions - `rioxarray `_: geospatial xarray extension powered by rasterio diff --git a/doc/terminology.rst b/doc/terminology.rst index d1265e4da9d..ab6d856920a 100644 --- a/doc/terminology.rst +++ b/doc/terminology.rst @@ -1,3 +1,4 @@ +.. currentmodule:: xarray .. _terminology: Terminology diff --git a/doc/time-series.rst b/doc/time-series.rst index 1cb535ea886..d838dbbd4cd 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -222,4 +222,4 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. For more examples of using grouped operations on a time dimension, see -:ref:`toy weather data`. +:doc:`examples/weather-data`. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d4d8ab8f3e5..1d7c425e554 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,59 +13,209 @@ What's New import xarray as xr np.random.seed(123456) -.. _whats-new.0.15.0: +.. _whats-new.0.15.1: -v0.15.0 (unreleased) --------------------- +v0.15.1 (unreleased) +--------------------- Breaking changes ~~~~~~~~~~~~~~~~ +New Features +~~~~~~~~~~~~ + +- Support new h5netcdf backend keyword ``phony_dims`` (available from h5netcdf + v0.8.0) for :py:class:`~xarray.backends.H5NetCDFStore`. + By `Kai Mühlbauer `_. +- Implement pint support. (:issue:`3594`, :pull:`3706`) + By `Justus Magin `_. + + +Bug fixes +~~~~~~~~~ +- Use ``dask_array_type`` instead of ``dask_array.Array`` for type + checking. (:issue:`3779`, :pull:`3787`) + By `Justus Magin `_. + +- :py:func:`concat` can now handle coordinate variables only present in one of + the objects to be concatenated when ``coords="different"``. + By `Deepak Cherian `_. + Documentation ~~~~~~~~~~~~~ Internal Changes ~~~~~~~~~~~~~~~~ +- Removed the internal ``import_seaborn`` function which handled the deprecation of + the ``seaborn.apionly`` entry point (:issue:`3747`). + By `Mathias Hauser `_. +- Changed test_open_mfdataset_list_attr to only run with dask installed + (:issue:`3777`, :pull:`3780`). + By `Bruno Pagani `_. + +.. _whats-new.0.15.0: + + +v0.15.0 (30 Jan 2020) +--------------------- + +This release brings many improvements to xarray's documentation: our examples are now binderized notebooks (`click here `_) +and we have new example notebooks from our SciPy 2019 sprint (many thanks to our contributors!). + +This release also features many API improvements such as a new +:py:class:`~core.accessor_dt.TimedeltaAccessor` and support for :py:class:`CFTimeIndex` in +:py:meth:`~DataArray.interpolate_na`, as well as many bug fixes. Breaking changes ~~~~~~~~~~~~~~~~ +- Bumped minimum tested versions for dependencies: + + - numpy 1.15 + - pandas 0.25 + - dask 2.2 + - distributed 2.2 + - scipy 1.3 + +- Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which + have been deprecated since 0.12. (:pull:`3650`). + Instead, specify the ``encoding`` kwarg when writing to disk or set + the :py:attr:`DataArray.encoding` attribute directly. + By `Maximilian Roos `_. +- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now + use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``; + :issue:`3694`) by `Mathias Hauser `_. New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`) + By `Keisuke Fujii `_. +- Support using an existing, opened h5netcdf ``File`` with + :py:class:`~xarray.backends.H5NetCDFStore`. 
This permits creating an + :py:class:`~xarray.Dataset` from an h5netcdf ``File`` that has been opened + using other means (:issue:`3618`). + By `Kai Mühlbauer `_. +- Implement ``median`` and ``nanmedian`` for dask arrays. This works by rechunking + to a single chunk along all reduction axes. (:issue:`2999`). + By `Deepak Cherian `_. +- :py:func:`~xarray.concat` now preserves attributes from the first Variable. + (:issue:`2575`, :issue:`2060`, :issue:`1614`) + By `Deepak Cherian `_. - :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile`` now work with dask Variables. By `Deepak Cherian `_. -- Added the :py:meth:`count` reduction method to both :py:class:`DatasetCoarsen` - and :py:class:`DataArrayCoarsen` objects. (:pull:`3500`) - By `Deepak Cherian `_ +- Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` + and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) + By `Deepak Cherian `_. +- Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`; + this is passed on to :py:func:`dask.array.blockwise`. (:pull:`3660`) + By `Deepak Cherian `_. +- Add ``attrs_file`` option in :py:func:`~xarray.open_mfdataset` to choose the + source file for global attributes in a multi-file dataset (:issue:`2382`, + :pull:`3498`). By `Julien Seguinot `_. +- :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` + now allow swapping to dimension names that don't exist yet. (:pull:`3636`) + By `Justus Magin `_. +- Extend :py:class:`~core.accessor_dt.DatetimeAccessor` properties + and support ``.dt`` accessor for timedeltas + via :py:class:`~core.accessor_dt.TimedeltaAccessor` (:pull:`3612`) + By `Anderson Banihirwe `_. +- Improvements to interpolating along time axes (:issue:`3641`, :pull:`3631`). + By `David Huard `_. + + - Support :py:class:`CFTimeIndex` in :py:meth:`DataArray.interpolate_na` + - define 1970-01-01 as the default offset for the interpolation index for both + :py:class:`pandas.DatetimeIndex` and :py:class:`CFTimeIndex`, + - use microseconds in the conversion from timedelta objects to floats to avoid + overflow errors. Bug fixes ~~~~~~~~~ +- Applying a user-defined function that adds new dimensions using :py:func:`apply_ufunc` + and ``vectorize=True`` now works with ``dask > 2.0``. (:issue:`3574`, :pull:`3660`). + By `Deepak Cherian `_. +- Fix :py:func:`~xarray.combine_by_coords` to allow for combining incomplete + hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger + `_. +- Fix :py:func:`~xarray.combine_by_coords` when combining cftime coordinates + which span long time intervals (:issue:`3535`). By `Spencer Clark + `_. - Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`) By `Deepak Cherian `_. +- :py:meth:`plot.FacetGrid.set_titles` can now replace existing row titles of a + :py:class:`~xarray.plot.FacetGrid` plot. In addition, :py:class:`~xarray.plot.FacetGrid` gained + two new attributes, :py:attr:`~xarray.plot.FacetGrid.col_labels` and + :py:attr:`~xarray.plot.FacetGrid.row_labels`, which contain :py:class:`matplotlib.text.Text` handles for both column and + row labels. These can be used to manually change the labels. + By `Deepak Cherian `_. +- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving :py:func:`map_blocks` (:pull:`3598`). + By `Tom Augspurger `_. +- Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error + when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. 
+- Fix regression in xarray 0.14.1 that prevented encoding times with certain + ``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`). + By `Spencer Clark `_. +- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to + rename to an existing name (:issue:`3438`, :pull:`3645`) + By `Justus Magin `_. +- :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with + MultiIndex level names. +- :py:meth:`Dataset.merge` no longer fails when passed a :py:class:`DataArray` instead of a :py:class:`Dataset`. + By `Tom Nicholas `_. +- Fix a regression in :py:meth:`Dataset.drop`: allow passing any + iterable when dropping variables (:issue:`3552`, :pull:`3693`) + By `Justus Magin `_. +- Fixed errors emitted by ``mypy --strict`` in modules that import xarray. + (:issue:`3695`) by `Guido Imperiale `_. +- Allow plotting of binned coordinates on the y axis in :py:meth:`plot.line` + and :py:meth:`plot.step` plots (:issue:`3571`, + :pull:`3685`) by `Julien Seguinot `_. +- setuptools is now marked as a dependency of xarray + (:pull:`3628`) by `Richard Höchenberger `_. Documentation ~~~~~~~~~~~~~ -- Switch doc examples to use nbsphinx and replace sphinx_gallery with - notebook. - (:pull:`3105`, :pull:`3106`, :pull:`3121`) - By `Ryan Abernathey ` -- Added example notebook demonstrating use of xarray with Regional Ocean - Modeling System (ROMS) ocean hydrodynamic model output. - (:pull:`3116`). - By `Robert Hetland ` -- Added example notebook demonstrating the visualization of ERA5 GRIB - data. (:pull:`3199`) - By `Zach Bruick ` and - `Stephan Siemen ` -- Added examples for `DataArray.quantile`, `Dataset.quantile` and - `GroupBy.quantile`. (:pull:`3576`) +- Switch doc examples to use `nbsphinx `_ and replace + ``sphinx_gallery`` scripts with Jupyter notebooks. (:pull:`3105`, :pull:`3106`, :pull:`3121`) + By `Ryan Abernathey `_. +- Added :doc:`example notebook ` demonstrating use of xarray with + Regional Ocean Modeling System (ROMS) ocean hydrodynamic model output. (:pull:`3116`) + By `Robert Hetland `_. +- Added :doc:`example notebook ` demonstrating the visualization of + ERA5 GRIB data. (:pull:`3199`) + By `Zach Bruick `_ and + `Stephan Siemen `_. +- Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and + ``GroupBy.quantile``. (:pull:`3576`) By `Justus Magin `_. +- Add new :doc:`example notebook ` demonstrating + vectorization of a 1D function using :py:func:`apply_ufunc`, dask and numba. + By `Deepak Cherian `_. +- Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`) + By `Riley X. Brady `_. Internal Changes ~~~~~~~~~~~~~~~~ +- Make sure dask names change when rechunking by different chunk sizes. Conversely, make sure they + stay the same when rechunking by the same chunk size. (:issue:`3350`) + By `Deepak Cherian `_. - 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`, :py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int, slice, list of int, scalar ndarray, or 1-dimensional ndarray. (:pull:`3533`) by `Guido Imperiale `_. -- Removed internal method ``Dataset._from_vars_and_coord_names``, +- Removed internal method ``Dataset._from_vars_and_coord_names``, which was dominated by ``Dataset._construct_direct``. (:pull:`3565`) - By `Maximilian Roos `_ + By `Maximilian Roos `_. +- Replaced versioneer with setuptools-scm. Moved contents of setup.py to setup.cfg. + Removed pytest-runner from setup.py, as per deprecation notice on the pytest-runner + project.
(:pull:`3714`) by `Guido Imperiale `_. +- Use of isort is now enforced by CI. + (:pull:`3721`) by `Guido Imperiale `_. +.. _whats-new.0.14.1: + v0.14.1 (19 Nov 2019) --------------------- @@ -88,8 +238,8 @@ Breaking changes New Features ~~~~~~~~~~~~ -- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, - :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, +- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, + :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, :py:meth:`~xarray.Dataset.reindex` (:issue:`3518`). By `Keisuke Fujii `_. - Added the ``fill_value`` option to :py:meth:`DataArray.unstack` and @@ -99,13 +249,13 @@ New Features :py:meth:`~xarray.Dataset.interpolate_na`. This controls the maximum size of the data gap that will be filled by interpolation. By `Deepak Cherian `_. - Added :py:meth:`Dataset.drop_sel` & :py:meth:`DataArray.drop_sel` for dropping labels. - :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for + :py:meth:`Dataset.drop_vars` & :py:meth:`DataArray.drop_vars` have been added for dropping variables (including coordinates). The existing :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` methods remain as a backward compatible option for dropping either labels or variables, but using the more specific methods is encouraged. (:pull:`3475`) By `Maximilian Roos `_ -- Added :py:meth:`Dataset.map` & :py:meth:`GroupBy.map` & :py:meth:`Resample.map` for +- Added :py:meth:`Dataset.map` & ``GroupBy.map`` & ``Resample.map`` for mapping / applying a function over each item in the collection, reflecting the widely used and least surprising name for this operation. The existing ``apply`` methods remain for backward compatibility, though using the ``map`` @@ -124,7 +274,7 @@ New Features - :py:func:`xarray.dot`, and :py:meth:`DataArray.dot` now support the ``dims=...`` option to sum over the union of dimensions of all input arrays (:issue:`3423`) by `Mathias Hauser `_. -- Added new :py:meth:`Dataset._repr_html_` and :py:meth:`DataArray._repr_html_` to improve +- Added new ``Dataset._repr_html_`` and ``DataArray._repr_html_`` to improve representation of objects in Jupyter. By default this feature is turned off for now. Enable it with ``xarray.set_options(display_style="html")``. (:pull:`3425`) by `Benoit Bovy `_ and @@ -133,22 +283,26 @@ New Features `_ for xarray objects. Note that xarray objects with a dask.array backend already used deterministic hashing in previous releases; this change implements it when whole - xarray objects are embedded in a dask graph, e.g. when :py:meth:`DataArray.map` is + xarray objects are embedded in a dask graph, e.g. when :py:meth:`DataArray.map_blocks` is invoked. (:issue:`3378`, :pull:`3446`, :pull:`3515`) By `Deepak Cherian `_ and `Guido Imperiale `_. -- Add the documented-but-missing :py:meth:`DatasetGroupBy.quantile`. +- xarray now respects the ``DataArray.encoding["coordinates"]`` attribute when writing to disk. + See :ref:`io.coordinates` for more. (:issue:`3351`, :pull:`3487`) + By `Deepak Cherian `_. +- Add the documented-but-missing :py:meth:`~core.groupby.DatasetGroupBy.quantile`. + (:issue:`3525`, :pull:`3527`). By `Justus Magin `_.
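To illustrate the documented-but-missing ``DatasetGroupBy.quantile`` mentioned above, a minimal sketch (the variable and coordinate names are invented for illustration)::

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"temperature": ("time", np.arange(8.0))},
        coords={"season": ("time", ["DJF", "DJF", "JJA", "JJA"] * 2)},
    )
    # quantile now works on grouped Datasets; q=0.5 yields per-group medians
    medians = ds.groupby("season").quantile(0.5)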
Bug fixes ~~~~~~~~~ -- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when +- Ensure an index of type ``CFTimeIndex`` is not converted to a ``DatetimeIndex`` when calling :py:meth:`Dataset.rename`, :py:meth:`Dataset.rename_dims` and :py:meth:`Dataset.rename_vars`. By `Mathias Hauser `_. (:issue:`3522`). - Fix a bug in :py:meth:`DataArray.set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`). By `Keisuke Fujii `_. - Harmonize ``_FillValue``, ``missing_value`` during encoding and decoding steps. (:pull:`3502`) - By `Anderson Banihirwe `_. + By `Anderson Banihirwe `_. - Fix regression introduced in v0.14.0 that would cause a crash if dask is installed but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle `_ - Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`). @@ -163,7 +317,7 @@ Bug fixes - Rolling reduction operations no longer compute dask arrays by default. (:issue:`3161`). In addition, the ``allow_lazy`` kwarg to ``reduce`` is deprecated. By `Deepak Cherian `_. -- Fix :py:meth:`GroupBy.reduce` when reducing over multiple dimensions. +- Fix ``GroupBy.reduce`` when reducing over multiple dimensions. (:issue:`3402`). By `Deepak Cherian `_ - Allow appending datetime and bool data variables to zarr stores. (:issue:`3480`). By `Akihiro Matsukawa `_. @@ -213,7 +367,7 @@ Internal Changes - Enable type checking on default sentinel values (:pull:`3472`) By `Maximilian Roos `_ -- Add :py:meth:`Variable._replace` for simpler replacing of a subset of attributes (:pull:`3472`) +- Add ``Variable._replace`` for simpler replacing of a subset of attributes (:pull:`3472`) By `Maximilian Roos `_ .. _whats-new.0.14.0: @@ -269,7 +423,7 @@ New functions/methods Enhancements ~~~~~~~~~~~~ -- :py:class:`~xarray.core.GroupBy` enhancements. By `Deepak Cherian `_. +- ``core.groupby.GroupBy`` enhancements. By `Deepak Cherian `_. - Added a repr (:pull:`3344`). Example:: @@ -304,7 +458,7 @@ Bug fixes - Fix error in concatenating unlabeled dimensions (:pull:`3362`). By `Deepak Cherian `_. - Warn if the ``dim`` kwarg is passed to rolling operations. This is redundant since a dimension is - specified when the :py:class:`DatasetRolling` or :py:class:`DataArrayRolling` object is created. + specified when the :py:class:`~core.rolling.DatasetRolling` or :py:class:`~core.rolling.DataArrayRolling` object is created. (:pull:`3362`). By `Deepak Cherian `_. Documentation @@ -377,7 +531,7 @@ Breaking changes - Reindexing with variables of a different dimension now raise an error (previously deprecated) - ``xarray.broadcast_array`` is removed (previously deprecated in favor of :py:func:`~xarray.broadcast`) -- :py:meth:`Variable.expand_dims` is removed (previously deprecated in favor of +- ``Variable.expand_dims`` is removed (previously deprecated in favor of :py:meth:`Variable.set_dims`) New functions/methods @@ -462,8 +616,7 @@ Enhancements - ``xarray.Dataset.drop`` now supports keyword arguments; dropping index labels by using both ``dim`` and ``labels`` or using a - :py:class:`~xarray.core.coordinates.DataArrayCoordinates` object are - deprecated (:issue:`2910`). + :py:class:`~core.coordinates.DataArrayCoordinates` object are deprecated (:issue:`2910`). By `Gregory Gundersen `_. - Added examples of :py:meth:`Dataset.set_index` and @@ -611,7 +764,7 @@ New functions/methods By `Alan Brammer `_ and `Ryan May `_. 
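The rolling entry above notes that the dimension is fixed when the rolling object is constructed, so reductions take no ``dim`` argument; a minimal sketch with invented data::

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(10.0), dims="time")
    # The window dimension and size are set on the constructor ...
    roller = da.rolling(time=3, center=True)
    # ... so the reduction itself needs no ``dim`` keyword.
    smoothed = roller.mean()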
-- :py:meth:`~xarray.core.GroupBy.quantile` is now a method of ``GroupBy`` +- ``GroupBy.quantile`` is now a method of ``GroupBy`` objects (:issue:`3018`). By `David Huard `_. @@ -1153,7 +1306,7 @@ Announcements of note: for more details. - We have a new :doc:`roadmap` that outlines our future development plans. -- `Dataset.apply` now properly documents the way `func` is called. +- ``Dataset.apply`` now properly documents the way `func` is called. By `Matti Eskelinen `_. Enhancements @@ -1585,7 +1738,7 @@ Backwards incompatible changes Enhancements ~~~~~~~~~~~~ -- Added :py:func:`~xarray.dot`, equivalent to :py:func:`np.einsum`. +- Added :py:func:`~xarray.dot`, equivalent to :py:func:`numpy.einsum`. Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option, which specifies the dimensions to sum over. (:issue:`1951`) @@ -1653,7 +1806,7 @@ Documentation - Added a new guide on :ref:`contributing` (:issue:`640`) By `Joe Hamman `_. -- Added apply_ufunc example to :ref:`toy weather data` (:issue:`1844`). +- Added apply_ufunc example to :ref:`/examples/weather-data.ipynb#Toy-weather-data` (:issue:`1844`). By `Liam Brannigan `_. - New entry `Why don’t aggregations return Python scalars?` in the :doc:`faq` (:issue:`1726`). @@ -1770,7 +1923,7 @@ Bug fixes coordinates of target, destination and keys. If there are any conflict among these coordinates, ``IndexError`` will be raised. By `Keisuke Fujii `_. -- Properly point :py:meth:`DataArray.__dask_scheduler__` to +- Properly point ``DataArray.__dask_scheduler__`` to ``dask.threaded.get``. By `Matthew Rocklin `_. - Bug fixes in :py:meth:`DataArray.plot.imshow`: all-NaN arrays and arrays with size one in some dimension can now be plotted, which is good for @@ -1982,7 +2135,7 @@ Enhancements - Support for :py:class:`pathlib.Path` objects added to :py:func:`~xarray.open_dataset`, :py:func:`~xarray.open_mfdataset`, - :py:func:`~xarray.to_netcdf`, and :py:func:`~xarray.save_mfdataset` + ``xarray.to_netcdf``, and :py:func:`~xarray.save_mfdataset` (:issue:`799`): .. ipython:: @@ -2390,7 +2543,7 @@ Enhancements By `Stephan Hoyer `_ and `Phillip J. Wolfram `_. -- New aggregation on rolling objects :py:meth:`DataArray.rolling(...).count()` +- New aggregation on rolling objects :py:meth:`~core.rolling.DataArrayRolling.count` which providing a rolling count of valid values (:issue:`1138`). Bug fixes @@ -3561,7 +3714,7 @@ Enhancements ``fillna`` works on both ``Dataset`` and ``DataArray`` objects, and uses index based alignment and broadcasting like standard binary operations. It also can be applied by group, as illustrated in - :ref:`fill with climatology`. + :ref:`/examples/weather-data.ipynb#Fill-missing-values-with-climatology`. - New ``xray.Dataset.assign`` and ``xray.Dataset.assign_coords`` methods patterned off the new :py:meth:`DataFrame.assign ` method in pandas: diff --git a/licenses/DASK_LICENSE b/licenses/DASK_LICENSE index 893bddfb933..e98784cd600 100644 --- a/licenses/DASK_LICENSE +++ b/licenses/DASK_LICENSE @@ -1,4 +1,4 @@ -:py:meth:`~xarray.DataArray.isin`Copyright (c) 2014-2018, Anaconda, Inc. and contributors +Copyright (c) 2014-2018, Anaconda, Inc. and contributors All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index a8005d319d6..5fc097f1f5e 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -1,20 +1,20 @@ """ Property-based tests for roundtripping between xarray and pandas objects. """ -import pytest - -pytest.importorskip("hypothesis") - from functools import partial -import hypothesis.extra.numpy as npst -import hypothesis.extra.pandas as pdst -import hypothesis.strategies as st -from hypothesis import given import numpy as np import pandas as pd +import pytest + import xarray as xr +pytest.importorskip("hypothesis") +import hypothesis.extra.numpy as npst # isort:skip +import hypothesis.extra.pandas as pdst # isort:skip +import hypothesis.strategies as st # isort:skip +from hypothesis import given # isort:skip + numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() ) diff --git a/readthedocs.yml b/readthedocs.yml index c64fa1b7b02..ad249bf8c09 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,8 +1,13 @@ +version: 2 + build: image: latest + conda: - file: ci/requirements/doc.yml + environment: ci/requirements/doc.yml + python: - version: 3.7 - setup_py_install: false + version: 3.8 + install: [] + formats: [] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000000..f73887ff5cc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +# This file is redundant with setup.cfg; +# it exists to let GitHub build the repository dependency graph +# https://help.github.com/en/github/visualizing-repository-data-with-graphs/listing-the-packages-that-a-repository-depends-on + +numpy >= 1.15 +pandas >= 0.25 +setuptools >= 41.2 diff --git a/setup.cfg b/setup.cfg index 21158e3b0ee..42dc53bb882 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,96 @@ +[metadata] +name = xarray +author = xarray Developers +author_email = xarray@googlegroups.com +license = Apache +description = N-D labeled arrays and datasets in Python +long_description_content_type=text/x-rst +long_description = + **xarray** (formerly **xray**) is an open source project and Python package + that makes working with labelled multi-dimensional arrays simple, + efficient, and fun! + + xarray introduces labels in the form of dimensions, coordinates and + attributes on top of raw NumPy_-like arrays, which allows for a more + intuitive, more concise, and less error-prone developer experience. + The package includes a large and growing library of domain-agnostic functions + for advanced analytics and visualization with these data structures. + + xarray was inspired by and borrows heavily from pandas_, the popular data + analysis package focused on labelled tabular data. + It is particularly tailored to working with netCDF_ files, which were the + source of xarray's data model, and integrates tightly with dask_ for parallel + computing. + + .. _NumPy: https://www.numpy.org + .. _pandas: https://pandas.pydata.org + .. _dask: https://dask.org + .. _netCDF: https://www.unidata.ucar.edu/software/netcdf + + Why xarray? + ----------- + Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called + "tensors") are an essential part of computational science. + They are encountered in a wide range of fields, including physics, astronomy, + geoscience, bioinformatics, engineering, finance, and deep learning. 
+ In Python, NumPy_ provides the fundamental data structure and API for + working with raw ND arrays. + However, real-world datasets are usually more than just raw numbers; + they have labels which encode information about how the array values map + to locations in space, time, etc. + + xarray doesn't just keep track of labels on arrays -- it uses them to provide a + powerful and concise interface. For example: + + - Apply operations over dimensions by name: ``x.sum('time')``. + - Select values by label instead of integer location: ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. + - Mathematical operations (e.g., ``x - y``) vectorize across multiple dimensions (array broadcasting) based on dimension names, not shape. + - Flexible split-apply-combine operations with groupby: ``x.groupby('time.dayofyear').mean()``. + - Database like alignment based on coordinate labels that smoothly handles missing values: ``x, y = xr.align(x, y, join='outer')``. + - Keep track of arbitrary metadata in the form of a Python dictionary: ``x.attrs``. + + Learn more + ---------- + - Documentation: ``_ + - Issue tracker: ``_ + - Source code: ``_ + - SciPy2015 talk: ``_ + +url = https://github.com/pydata/xarray +classifiers = + Development Status :: 5 - Production/Stable + License :: OSI Approved :: Apache Software License + Operating System :: OS Independent + Intended Audience :: Science/Research + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Topic :: Scientific/Engineering + +[options] +packages = xarray +zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html +include_package_data = True +python_requires = >=3.6 +install_requires = + numpy >= 1.15 + pandas >= 0.25 + setuptools >= 41.2 # For pkg_resources +setup_requires = + setuptools >= 41.2 + setuptools_scm + +[options.package_data] +xarray = + py.typed + tests/data/* + static/css/* + static/html/* + [tool:pytest] -python_files=test_*.py -testpaths=xarray/tests properties +python_files = test_*.py +testpaths = xarray/tests properties # Fixed upstream in https://github.com/pydata/bottleneck/pull/199 filterwarnings = ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning @@ -12,7 +102,7 @@ markers = slow: slow tests [flake8] -ignore= +ignore = # whitespace before ':' - doesn't work well with black E203 E402 @@ -23,16 +113,17 @@ ignore= # line break before binary operator W503 exclude= + .eggs doc [isort] -default_section=THIRDPARTY -known_first_party=xarray -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -line_length=88 +default_section = THIRDPARTY +known_first_party = xarray +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +line_length = 88 # Most of the numerical computing stack doesn't have type annotations yet. 
[mypy-affine.*] @@ -87,35 +178,19 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-seaborn.*] ignore_missing_imports = True +[mypy-setuptools] +ignore_missing_imports = True [mypy-sparse.*] ignore_missing_imports = True [mypy-toolz.*] ignore_missing_imports = True [mypy-zarr.*] ignore_missing_imports = True - -# setuptools is not typed -[mypy-setup] -ignore_errors = True -# versioneer code -[mypy-versioneer.*] -ignore_errors = True -# written by versioneer -[mypy-xarray._version] -ignore_errors = True # version spanning code is hard to type annotate (and most of this module will # be going away soon anyways) [mypy-xarray.core.pycompat] ignore_errors = True -[versioneer] -VCS = git -style = pep440 -versionfile_source = xarray/_version.py -versionfile_build = xarray/_version.py -tag_prefix = v -parentdir_prefix = xarray- - [aliases] test = pytest diff --git a/setup.py b/setup.py index cba0c74aa3a..76755a445f7 100755 --- a/setup.py +++ b/setup.py @@ -1,110 +1,4 @@ #!/usr/bin/env python -import sys +from setuptools import setup -import versioneer -from setuptools import find_packages, setup - -DISTNAME = "xarray" -LICENSE = "Apache" -AUTHOR = "xarray Developers" -AUTHOR_EMAIL = "xarray@googlegroups.com" -URL = "https://github.com/pydata/xarray" -CLASSIFIERS = [ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Intended Audience :: Science/Research", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Topic :: Scientific/Engineering", -] - -PYTHON_REQUIRES = ">=3.6" -INSTALL_REQUIRES = ["numpy >= 1.14", "pandas >= 0.24"] -needs_pytest = {"pytest", "test", "ptr"}.intersection(sys.argv) -SETUP_REQUIRES = ["pytest-runner >= 4.2"] if needs_pytest else [] -TESTS_REQUIRE = ["pytest >= 2.7.1"] - -DESCRIPTION = "N-D labeled arrays and datasets in Python" -LONG_DESCRIPTION = """ -**xarray** (formerly **xray**) is an open source project and Python package -that makes working with labelled multi-dimensional arrays simple, -efficient, and fun! - -Xarray introduces labels in the form of dimensions, coordinates and -attributes on top of raw NumPy_-like arrays, which allows for a more -intuitive, more concise, and less error-prone developer experience. -The package includes a large and growing library of domain-agnostic functions -for advanced analytics and visualization with these data structures. - -Xarray was inspired by and borrows heavily from pandas_, the popular data -analysis package focused on labelled tabular data. -It is particularly tailored to working with netCDF_ files, which were the -source of xarray's data model, and integrates tightly with dask_ for parallel -computing. - -.. _NumPy: https://www.numpy.org -.. _pandas: https://pandas.pydata.org -.. _dask: https://dask.org -.. _netCDF: https://www.unidata.ucar.edu/software/netcdf - -Why xarray? ------------ - -Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called -"tensors") are an essential part of computational science. -They are encountered in a wide range of fields, including physics, astronomy, -geoscience, bioinformatics, engineering, finance, and deep learning. -In Python, NumPy_ provides the fundamental data structure and API for -working with raw ND arrays. 
-However, real-world datasets are usually more than just raw numbers; -they have labels which encode information about how the array values map -to locations in space, time, etc. - -Xarray doesn't just keep track of labels on arrays -- it uses them to provide a -powerful and concise interface. For example: - -- Apply operations over dimensions by name: ``x.sum('time')``. -- Select values by label instead of integer location: - ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. -- Mathematical operations (e.g., ``x - y``) vectorize across multiple - dimensions (array broadcasting) based on dimension names, not shape. -- Flexible split-apply-combine operations with groupby: - ``x.groupby('time.dayofyear').mean()``. -- Database like alignment based on coordinate labels that smoothly - handles missing values: ``x, y = xr.align(x, y, join='outer')``. -- Keep track of arbitrary metadata in the form of a Python dictionary: - ``x.attrs``. - -Learn more ----------- - -- Documentation: http://xarray.pydata.org -- Issue tracker: http://github.com/pydata/xarray/issues -- Source code: http://github.com/pydata/xarray -- SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk -""" - - -setup( - name=DISTNAME, - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - license=LICENSE, - author=AUTHOR, - author_email=AUTHOR_EMAIL, - classifiers=CLASSIFIERS, - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - python_requires=PYTHON_REQUIRES, - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - tests_require=TESTS_REQUIRE, - url=URL, - packages=find_packages(), - package_data={ - "xarray": ["py.typed", "tests/data/*", "static/css/*", "static/html/*"] - }, -) +setup(use_scm_version=True) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 7b55d5d06cb..00000000000 --- a/versioneer.py +++ /dev/null @@ -1,1883 +0,0 @@ -# flake8: noqa - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. 
- - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. 
- -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). 
The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. 
- - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - - -import errno -import json -import os -import re -import subprocess -import sys - -try: - import configparser -except ImportError: - import ConfigParser as configparser - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY[ - "git" -] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. 
Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except OSError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. 
- -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) - if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always --long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. 
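# For illustration (an editor's sketch, not part of the original file): given
# pieces like {"closest-tag": "0.15.0", "distance": 3, "short": "1a2b3c4",
# "dirty": True}, the renderers defined above produce:
#   pep440            -> "0.15.0+3.g1a2b3c4.dirty"
#   pep440-pre        -> "0.15.0.post.dev3"   (the dirty flag is ignored)
#   pep440-post       -> "0.15.0.post3.dev0+g1a2b3c4"
#   pep440-old        -> "0.15.0.post3.dev0"
#   git-describe      -> "0.15.0-3-g1a2b3c4-dirty"
#   git-describe-long -> "0.15.0-3-g1a2b3c4-dirty"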
- - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to its pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? 
- # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if "py2exe" in sys.modules: # py2exe enabled? 
- try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file( - target_versionfile, self._versioneer_generated_versions - ) - - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. 
- -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except OSError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except OSError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print( - " appending versionfile_source ('%s') to MANIFEST.in" - % cfg.versionfile_source - ) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. 
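# For illustration (not part of the original file): with xarray's
# configuration, the line that do_vcs_install() below appends to
# .gitattributes reads "xarray/_version.py export-subst".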
- do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) diff --git a/xarray/__init__.py b/xarray/__init__.py index 394dd0f80bc..331d8ecb09a 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,44 +1,88 @@ -""" isort:skip_file """ -# flake8: noqa - -from ._version import get_versions - -__version__ = get_versions()["version"] -del get_versions - -from .core.alignment import align, broadcast -from .core.common import full_like, zeros_like, ones_like -from .core.concat import concat -from .core.combine import combine_by_coords, combine_nested, auto_combine -from .core.computation import apply_ufunc, dot, where -from .core.extensions import register_dataarray_accessor, register_dataset_accessor -from .core.variable import as_variable, Variable, IndexVariable, Coordinate -from .core.dataset import Dataset -from .core.dataarray import DataArray -from .core.merge import merge, MergeError -from .core.options import set_options -from .core.parallel import map_blocks +import pkg_resources +from . import testing, tutorial, ufuncs from .backends.api import ( - open_dataset, + load_dataarray, + load_dataset, open_dataarray, + open_dataset, open_mfdataset, save_mfdataset, - load_dataset, - load_dataarray, ) from .backends.rasterio_ import open_rasterio from .backends.zarr import open_zarr - -from .conventions import decode_cf, SerializationWarning - from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex - +from .conventions import SerializationWarning, decode_cf +from .core.alignment import align, broadcast +from .core.combine import auto_combine, combine_by_coords, combine_nested +from .core.common import ALL_DIMS, full_like, ones_like, zeros_like +from .core.computation import apply_ufunc, dot, where +from .core.concat import concat +from .core.dataarray import DataArray +from .core.dataset import Dataset +from .core.extensions import register_dataarray_accessor, register_dataset_accessor +from .core.merge import MergeError, merge +from .core.options import set_options +from .core.parallel import map_blocks +from .core.variable import Coordinate, IndexVariable, Variable, as_variable from .util.print_versions import show_versions -from . import tutorial -from . 
import ufuncs -from . import testing +try: + __version__ = pkg_resources.get_distribution("xarray").version +except Exception: + # Local copy or not installed with setuptools. + # Disable minimum version checks on downstream libraries. + __version__ = "999" -from .core.common import ALL_DIMS +# A hardcoded __all__ variable is necessary to appease +# `mypy --strict` running in projects that import xarray. +__all__ = ( + # Sub-packages + "ufuncs", + "testing", + "tutorial", + # Top-level functions + "align", + "apply_ufunc", + "as_variable", + "auto_combine", + "broadcast", + "cftime_range", + "combine_by_coords", + "combine_nested", + "concat", + "decode_cf", + "dot", + "full_like", + "load_dataarray", + "load_dataset", + "map_blocks", + "merge", + "ones_like", + "open_dataarray", + "open_dataset", + "open_mfdataset", + "open_rasterio", + "open_zarr", + "register_dataarray_accessor", + "register_dataset_accessor", + "save_mfdataset", + "set_options", + "show_versions", + "where", + "zeros_like", + # Classes + "CFTimeIndex", + "Coordinate", + "DataArray", + "Dataset", + "IndexVariable", + "Variable", + # Exceptions + "MergeError", + "SerializationWarning", + # Constants + "__version__", + "ALL_DIMS", +) diff --git a/xarray/_version.py b/xarray/_version.py deleted file mode 100644 index 0ccb33a5e56..00000000000 --- a/xarray/_version.py +++ /dev/null @@ -1,555 +0,0 @@ -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by github's download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
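# For illustration (not part of the original file): in a "git archive"
# tarball the $Format$ placeholders below are expanded by git, e.g.
#   git_full     -> the full 40-character commit hash (%H)
#   git_refnames -> a decoration string such as " (tag: v0.15.0)" (%d)
#   git_date     -> an ISO-8601-like commit date (%ci)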
- git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "xarray-" - cfg.versionfile_source = "xarray/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print(f"unable to find command, tried {commands}") - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
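# For illustration (not part of the original file): the scan below matches
# lines of an already-expanded _version.py, e.g.
#   git_refnames = " (HEAD -> master, tag: v0.15.0)"
# and extracts the quoted value with a regexp instead of importing the file.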
- keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format( - full_tag, tag_prefix - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always --long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split("/"): - root = os.path.dirname(root) - except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 23d09ba5e33..56cd0649989 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -18,13 +18,16 @@ import numpy as np -from .. import DataArray, Dataset, auto_combine, backends, coding, conventions +from .. 
import backends, coding, conventions from ..core import indexing from ..core.combine import ( _infer_concat_order_from_positions, _nested_combine, + auto_combine, combine_by_coords, ) +from ..core.dataarray import DataArray +from ..core.dataset import Dataset from ..core.utils import close_on_error, is_grib_path, is_remote_uri from .common import AbstractDataStore, ArrayWriter from .locks import _get_scheduler @@ -503,7 +506,7 @@ def maybe_decode_store(store, lock=False): elif engine == "pydap": store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) elif engine == "pynio": @@ -527,7 +530,7 @@ def maybe_decode_store(store, lock=False): if engine == "scipy": store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) @@ -718,6 +721,7 @@ def open_mfdataset( autoclose=None, parallel=False, join="outer", + attrs_file=None, **kwargs, ): """Open multiple files as a single dataset. @@ -729,8 +733,8 @@ def open_mfdataset( ``combine_by_coords`` and ``combine_nested``. By default the old (now deprecated) ``auto_combine`` will be used, please specify either ``combine='by_coords'`` or ``combine='nested'`` in future. Requires dask to be installed. See documentation for - details on dask [1]_. Attributes from the first dataset file are used for the - combined dataset. + details on dask [1]_. Global attributes from the ``attrs_file`` are used + for the combined dataset. Parameters ---------- @@ -827,6 +831,10 @@ def open_mfdataset( - 'override': if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. + attrs_file : str or pathlib.Path, optional + Path of the file used to read global attributes from. + By default global attributes are read from the first file provided, + with wildcard matches sorted by filename. **kwargs : optional Additional arguments passed on to :py:func:`xarray.open_dataset`. @@ -961,14 +969,22 @@ def open_mfdataset( raise combined._file_obj = _MultiFileCloser(file_objs) - combined.attrs = datasets[0].attrs + + # read global attributes from the attrs_file or from the first dataset + if attrs_file is not None: + if isinstance(attrs_file, Path): + attrs_file = str(attrs_file) + combined.attrs = datasets[paths.index(attrs_file)].attrs + else: + combined.attrs = datasets[0].attrs + return combined WRITEABLE_STORES: Dict[str, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, - "h5netcdf": backends.H5NetCDFStore, + "h5netcdf": backends.H5NetCDFStore.open, } diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 97c9ac1b9b4..bd946df89b2 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -1,8 +1,8 @@ import numpy as np -from .. 
import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .locks import SerializableLock, ensure_lock diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 9786e0a0203..fa3ee19f542 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -34,7 +34,7 @@ def find_root_and_group(ds): """Find the root and group name of a netCDF4/h5netcdf dataset.""" hierarchy = () while ds.parent is not None: - hierarchy = (ds.name,) + hierarchy + hierarchy = (ds.name.split("/")[-1],) + hierarchy ds = ds.parent group = "/" + "/".join(hierarchy) return ds, group diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 51ed512f98b..393db14a7e9 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -1,12 +1,13 @@ import functools +from distutils.version import LooseVersion import numpy as np -from .. import Variable from ..core import indexing -from ..core.utils import FrozenDict -from .common import WritableCFDataStore -from .file_manager import CachingFileManager +from ..core.utils import FrozenDict, is_remote_uri +from ..core.variable import Variable +from .common import WritableCFDataStore, find_root_and_group +from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( BaseNetCDF4Array, @@ -69,8 +70,47 @@ class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ - def __init__( - self, + __slots__ = ( + "autoclose", + "format", + "is_remote", + "lock", + "_filename", + "_group", + "_manager", + "_mode", + ) + + def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + + import h5netcdf + + if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): + if group is None: + root, group = find_root_and_group(manager) + else: + if not type(manager) is h5netcdf.File: + raise ValueError( + "must supply a h5netcdf.File if the group " + "argument is provided" + ) + root = manager + manager = DummyFileManager(root) + + self._manager = manager + self._group = group + self._mode = mode + self.format = None + # todo: utilizing find_root_and_group seems a bit clunky + # making filename available on h5netcdf.Group seems better + self._filename = find_root_and_group(self.ds)[0].filename + self.is_remote = is_remote_uri(self._filename) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @classmethod + def open( + cls, filename, mode="r", format=None, @@ -78,6 +118,7 @@ def __init__( lock=None, autoclose=False, invalid_netcdf=None, + phony_dims=None, ): import h5netcdf @@ -85,10 +126,14 @@ def __init__( raise ValueError("invalid format for h5netcdf backend") kwargs = {"invalid_netcdf": invalid_netcdf} - - self._manager = CachingFileManager( - h5netcdf.File, filename, mode=mode, kwargs=kwargs - ) + if phony_dims is not None: + if LooseVersion(h5netcdf.__version__) >= LooseVersion("0.8.0"): + kwargs["phony_dims"] = phony_dims + else: + raise ValueError( + "h5netcdf backend keyword argument 'phony_dims' needs " + "h5netcdf >= 0.8.0." 
+ ) if lock is None: if mode == "r": @@ -96,12 +141,8 @@ def __init__( else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) - self._group = group - self.format = format - self._filename = filename - self._mode = mode - self.lock = ensure_lock(lock) - self.autoclose = autoclose + manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) + return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): with self._manager.acquire_context(needs_lock) as root: diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0e454ec47de..0a917cde4d7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -4,10 +4,11 @@ import numpy as np -from .. import Variable, coding +from .. import coding from ..coding.variables import pop_to from ..core import indexing from ..core.utils import FrozenDict, is_remote_uri +from ..core.variable import Variable from .common import ( BackendArray, WritableCFDataStore, diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index d26b6ce2ea9..c9c4baf9b01 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -2,7 +2,8 @@ import numpy as np -from .. import Variable, coding +from .. import coding +from ..core.variable import Variable # Special characters that are permitted in netCDF names except in the # 0th position of the string diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 63a7c3b0609..17a4eb8f6bf 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,8 +1,8 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 7ef4ec66241..20e943ab561 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,9 +1,9 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.pycompat import integer_types from ..core.utils import Frozen, FrozenDict, is_dict_like +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray, robust_getitem diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index d9e372ceaf9..1c66ff1ee48 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,8 +1,8 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 1f9b9943573..d041e430db9 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -3,8 +3,8 @@ import numpy as np -from .. 
import DataArray from ..core import indexing +from ..core.dataarray import DataArray from ..core.utils import is_scalar from .common import BackendArray from .file_manager import CachingFileManager diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 3a4787fc67e..9863285d6de 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -2,9 +2,9 @@ import numpy as np -from .. import Variable from ..core.indexing import NumpyIndexingAdapter from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import BackendArray, WritableCFDataStore from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6d4ebb02a11..763769dac74 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -2,10 +2,11 @@ import numpy as np -from .. import Variable, coding, conventions +from .. import coding, conventions from ..core import indexing from ..core.pycompat import integer_types from ..core.utils import FrozenDict, HiddenKeyDict +from ..core.variable import Variable from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name # need some special secret attributes to tell us the dimensions diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 4005d4fbf6d..8b440812ca9 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -430,7 +430,14 @@ def __sub__(self, other): import cftime if isinstance(other, (CFTimeIndex, cftime.datetime)): - return pd.TimedeltaIndex(np.array(self) - np.array(other)) + try: + return pd.TimedeltaIndex(np.array(self) - np.array(other)) + except OverflowError: + raise ValueError( + "The time difference exceeds the range of values " + "that can be expressed at the nanosecond resolution." 
+ ) + elif isinstance(other, pd.TimedeltaIndex): return CFTimeIndex(np.array(self) - other.to_pytimedelta()) else: diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 2b5f87ab0cd..28ead397461 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -148,6 +148,7 @@ class CFMaskCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) + dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") @@ -162,14 +163,14 @@ def encode(self, variable, name=None): if fv is not None: # Ensure _FillValue is cast to same dtype as data's - encoding["_FillValue"] = data.dtype.type(fv) + encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if not pd.isnull(fill_value): data = duck_array_ops.fillna(data, fill_value) if mv is not None: # Ensure missing_value is cast to same dtype as data's - encoding["missing_value"] = data.dtype.type(mv) + encoding["missing_value"] = dtype.type(mv) fill_value = pop_to(encoding, attrs, "missing_value", name=name) if not pd.isnull(fill_value) and fv is None: data = duck_array_ops.fillna(data, fill_value) diff --git a/xarray/conventions.py b/xarray/conventions.py index a83b4b31c17..a8b9906c153 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -5,7 +5,7 @@ import pandas as pd from .coding import strings, times, variables -from .coding.variables import SerializationWarning +from .coding.variables import SerializationWarning, pop_to from .core import duck_array_ops, indexing from .core.common import contains_cftime_datetimes from .core.pycompat import dask_array_type @@ -660,34 +660,46 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): and set(target_dims) <= set(v.dims) ): variable_coordinates[k].add(coord_name) - global_coordinates.discard(coord_name) variables = {k: v.copy(deep=False) for k, v in variables.items()} - # These coordinates are saved according to CF conventions - for var_name, coord_names in variable_coordinates.items(): - attrs = variables[var_name].attrs - if "coordinates" in attrs: + # keep track of variable names written to file under the "coordinates" attributes + written_coords = set() + for name, var in variables.items(): + encoding = var.encoding + attrs = var.attrs + if "coordinates" in attrs and "coordinates" in encoding: raise ValueError( - "cannot serialize coordinates because variable " - "%s already has an attribute 'coordinates'" % var_name + f"'coordinates' found in both attrs and encoding for variable {name!r}." ) - attrs["coordinates"] = " ".join(map(str, coord_names)) + + # this will copy coordinates from encoding to attrs if "coordinates" in attrs + # after the next line, "coordinates" is never in encoding + # we get support for attrs["coordinates"] for free. + coords_str = pop_to(encoding, attrs, "coordinates") + if not coords_str and variable_coordinates[name]: + attrs["coordinates"] = " ".join(map(str, variable_coordinates[name])) + if "coordinates" in attrs: + written_coords.update(attrs["coordinates"].split()) # These coordinates are not associated with any particular variables, so we # save them under a global 'coordinates' attribute so xarray can roundtrip # the dataset faithfully. Because this serialization goes beyond CF # conventions, only do it if necessary. 
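Editor's note on the `_encode_coordinates` rework above: a `"coordinates"` attribute may now be supplied through `encoding`, and supplying it in both `attrs` and `encoding` raises a `ValueError`. A minimal sketch of the intended usage; the variable names and output path are illustrative:

```python
import xarray as xr

ds = xr.Dataset(
    {"temp": ("x", [10.0, 11.0])},
    coords={"lon": ("x", [100.0, 101.0])},
)

# Request an explicit "coordinates" attribute via encoding; per the hunk
# above, pop_to() moves it from encoding into attrs at write time, so
# "coordinates" never reaches the store through encoding itself.
ds["temp"].encoding["coordinates"] = "lon"
ds.to_netcdf("example.nc")  # temp is written with coordinates = "lon"
```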
     # Reference discussion:
-    # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/057771.html
+    # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/007571.html
+    global_coordinates.difference_update(written_coords)
     if global_coordinates:
         attributes = dict(attributes)
         if "coordinates" in attributes:
-            raise ValueError(
-                "cannot serialize coordinates because the global "
-                "attribute 'coordinates' already exists"
+            warnings.warn(
+                f"cannot serialize global coordinates {global_coordinates!r} because the global "
+                f"attribute 'coordinates' already exists. This may prevent faithful roundtripping "
+                f"of xarray datasets",
+                SerializationWarning,
             )
-        attributes["coordinates"] = " ".join(map(str, global_coordinates))
+        else:
+            attributes["coordinates"] = " ".join(map(str, global_coordinates))
 
     return variables, attributes
diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py
index aff6fbc6691..c407371f9f0 100644
--- a/xarray/core/accessor_dt.py
+++ b/xarray/core/accessor_dt.py
@@ -1,7 +1,11 @@
 import numpy as np
 import pandas as pd
 
-from .common import _contains_datetime_like_objects, is_np_datetime_like
+from .common import (
+    _contains_datetime_like_objects,
+    is_np_datetime_like,
+    is_np_timedelta_like,
+)
 from .pycompat import dask_array_type
@@ -145,37 +149,8 @@ def _strftime(values, date_format):
     return access_method(values, date_format)
 
-class DatetimeAccessor:
-    """Access datetime fields for DataArrays with datetime-like dtypes.
-
-    Similar to pandas, fields can be accessed through the `.dt` attribute
-    for applicable DataArrays:
-
-    >>> ds = xarray.Dataset({'time': pd.date_range(start='2000/01/01',
-    ...                                            freq='D', periods=100)})
-    >>> ds.time.dt
-    <xarray.core.accessor_dt.DatetimeAccessor object at 0x...>
-    >>> ds.time.dt.dayofyear[:5]
-    <xarray.DataArray 'dayofyear' (time: 5)>
-    array([1, 2, 3, 4, 5], dtype=int32)
-    Coordinates:
-      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ...
-
-    All of the pandas fields are accessible here. Note that these fields are
-    not calendar-aware; if your datetimes are encoded with a non-Gregorian
-    calendar (e.g. a 360-day calendar) using cftime, then some fields like
-    `dayofyear` may not be accurate.
-
-    """
-
+class Properties:
     def __init__(self, obj):
-        if not _contains_datetime_like_objects(obj):
-            raise TypeError(
-                "'dt' accessor only available for "
-                "DataArray with datetime64 timedelta64 dtype or "
-                "for arrays containing cftime datetime "
-                "objects."
- ) self._obj = obj def _tslib_field_accessor( # type: ignore @@ -194,48 +169,6 @@ def f(self, dtype=dtype): f.__doc__ = docstring return property(f) - year = _tslib_field_accessor("year", "The year of the datetime", np.int64) - month = _tslib_field_accessor( - "month", "The month as January=1, December=12", np.int64 - ) - day = _tslib_field_accessor("day", "The days of the datetime", np.int64) - hour = _tslib_field_accessor("hour", "The hours of the datetime", np.int64) - minute = _tslib_field_accessor("minute", "The minutes of the datetime", np.int64) - second = _tslib_field_accessor("second", "The seconds of the datetime", np.int64) - microsecond = _tslib_field_accessor( - "microsecond", "The microseconds of the datetime", np.int64 - ) - nanosecond = _tslib_field_accessor( - "nanosecond", "The nanoseconds of the datetime", np.int64 - ) - weekofyear = _tslib_field_accessor( - "weekofyear", "The week ordinal of the year", np.int64 - ) - week = weekofyear - dayofweek = _tslib_field_accessor( - "dayofweek", "The day of the week with Monday=0, Sunday=6", np.int64 - ) - weekday = dayofweek - - weekday_name = _tslib_field_accessor( - "weekday_name", "The name of day in a week (ex: Friday)", object - ) - - dayofyear = _tslib_field_accessor( - "dayofyear", "The ordinal day of the year", np.int64 - ) - quarter = _tslib_field_accessor("quarter", "The quarter of the date") - days_in_month = _tslib_field_accessor( - "days_in_month", "The number of days in the month", np.int64 - ) - daysinmonth = days_in_month - - season = _tslib_field_accessor("season", "Season of the year (ex: DJF)", object) - - time = _tslib_field_accessor( - "time", "Timestamps corresponding to datetimes", object - ) - def _tslib_round_accessor(self, name, freq): obj_type = type(self._obj) result = _round_field(self._obj.data, name, freq) @@ -290,6 +223,50 @@ def round(self, freq): """ return self._tslib_round_accessor("round", freq) + +class DatetimeAccessor(Properties): + """Access datetime fields for DataArrays with datetime-like dtypes. + + Fields can be accessed through the `.dt` attribute + for applicable DataArrays. + + Notes + ------ + Note that these fields are not calendar-aware; if your datetimes are encoded + with a non-Gregorian calendar (e.g. a 360-day calendar) using cftime, + then some fields like `dayofyear` may not be accurate. + + Examples + --------- + >>> import xarray as xr + >>> import pandas as pd + >>> dates = pd.date_range(start='2000/01/01', freq='D', periods=10) + >>> ts = xr.DataArray(dates, dims=('time')) + >>> ts + + array(['2000-01-01T00:00:00.000000000', '2000-01-02T00:00:00.000000000', + '2000-01-03T00:00:00.000000000', '2000-01-04T00:00:00.000000000', + '2000-01-05T00:00:00.000000000', '2000-01-06T00:00:00.000000000', + '2000-01-07T00:00:00.000000000', '2000-01-08T00:00:00.000000000', + '2000-01-09T00:00:00.000000000', '2000-01-10T00:00:00.000000000'], + dtype='datetime64[ns]') + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + >>> ts.dt + + >>> ts.dt.dayofyear + + array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + >>> ts.dt.quarter + + array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 
2000-01-10 + + """ + def strftime(self, date_format): ''' Return an array of formatted strings specified by date_format, which @@ -323,3 +300,163 @@ def strftime(self, date_format): return obj_type( result, name="strftime", coords=self._obj.coords, dims=self._obj.dims ) + + year = Properties._tslib_field_accessor( + "year", "The year of the datetime", np.int64 + ) + month = Properties._tslib_field_accessor( + "month", "The month as January=1, December=12", np.int64 + ) + day = Properties._tslib_field_accessor("day", "The days of the datetime", np.int64) + hour = Properties._tslib_field_accessor( + "hour", "The hours of the datetime", np.int64 + ) + minute = Properties._tslib_field_accessor( + "minute", "The minutes of the datetime", np.int64 + ) + second = Properties._tslib_field_accessor( + "second", "The seconds of the datetime", np.int64 + ) + microsecond = Properties._tslib_field_accessor( + "microsecond", "The microseconds of the datetime", np.int64 + ) + nanosecond = Properties._tslib_field_accessor( + "nanosecond", "The nanoseconds of the datetime", np.int64 + ) + weekofyear = Properties._tslib_field_accessor( + "weekofyear", "The week ordinal of the year", np.int64 + ) + week = weekofyear + dayofweek = Properties._tslib_field_accessor( + "dayofweek", "The day of the week with Monday=0, Sunday=6", np.int64 + ) + weekday = dayofweek + + weekday_name = Properties._tslib_field_accessor( + "weekday_name", "The name of day in a week", object + ) + + dayofyear = Properties._tslib_field_accessor( + "dayofyear", "The ordinal day of the year", np.int64 + ) + quarter = Properties._tslib_field_accessor("quarter", "The quarter of the date") + days_in_month = Properties._tslib_field_accessor( + "days_in_month", "The number of days in the month", np.int64 + ) + daysinmonth = days_in_month + + season = Properties._tslib_field_accessor("season", "Season of the year", object) + + time = Properties._tslib_field_accessor( + "time", "Timestamps corresponding to datetimes", object + ) + + is_month_start = Properties._tslib_field_accessor( + "is_month_start", + "Indicates whether the date is the first day of the month.", + bool, + ) + is_month_end = Properties._tslib_field_accessor( + "is_month_end", "Indicates whether the date is the last day of the month.", bool + ) + is_quarter_start = Properties._tslib_field_accessor( + "is_quarter_start", + "Indicator for whether the date is the first day of a quarter.", + bool, + ) + is_quarter_end = Properties._tslib_field_accessor( + "is_quarter_end", + "Indicator for whether the date is the last day of a quarter.", + bool, + ) + is_year_start = Properties._tslib_field_accessor( + "is_year_start", "Indicate whether the date is the first day of a year.", bool + ) + is_year_end = Properties._tslib_field_accessor( + "is_year_end", "Indicate whether the date is the last day of the year.", bool + ) + is_leap_year = Properties._tslib_field_accessor( + "is_leap_year", "Boolean indicator if the date belongs to a leap year.", bool + ) + + +class TimedeltaAccessor(Properties): + """Access Timedelta fields for DataArrays with Timedelta-like dtypes. + + Fields can be accessed through the `.dt` attribute for applicable DataArrays. 
+
+    Examples
+    --------
+    >>> import pandas as pd
+    >>> import xarray as xr
+    >>> dates = pd.timedelta_range(start="1 day", freq="6H", periods=20)
+    >>> ts = xr.DataArray(dates, dims=('time'))
+    >>> ts
+    <xarray.DataArray (time: 20)>
+    array([ 86400000000000, 108000000000000, 129600000000000, 151200000000000,
+            172800000000000, 194400000000000, 216000000000000, 237600000000000,
+            259200000000000, 280800000000000, 302400000000000, 324000000000000,
+            345600000000000, 367200000000000, 388800000000000, 410400000000000,
+            432000000000000, 453600000000000, 475200000000000, 496800000000000],
+          dtype='timedelta64[ns]')
+    Coordinates:
+      * time     (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00
+    >>> ts.dt
+    <xarray.core.accessor_dt.TimedeltaAccessor object at 0x...>
+    >>> ts.dt.days
+    <xarray.DataArray 'days' (time: 20)>
+    array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5])
+    Coordinates:
+      * time     (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00
+    >>> ts.dt.microseconds
+    <xarray.DataArray 'microseconds' (time: 20)>
+    array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+    Coordinates:
+      * time     (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00
+    >>> ts.dt.seconds
+    <xarray.DataArray 'seconds' (time: 20)>
+    array([    0, 21600, 43200, 64800,     0, 21600, 43200, 64800,     0,
+           21600, 43200, 64800,     0, 21600, 43200, 64800,     0, 21600,
+           43200, 64800])
+    Coordinates:
+      * time     (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00
+    """
+
+    days = Properties._tslib_field_accessor(
+        "days", "Number of days for each element.", np.int64
+    )
+    seconds = Properties._tslib_field_accessor(
+        "seconds",
+        "Number of seconds (>= 0 and less than 1 day) for each element.",
+        np.int64,
+    )
+    microseconds = Properties._tslib_field_accessor(
+        "microseconds",
+        "Number of microseconds (>= 0 and less than 1 second) for each element.",
+        np.int64,
+    )
+    nanoseconds = Properties._tslib_field_accessor(
+        "nanoseconds",
+        "Number of nanoseconds (>= 0 and less than 1 microsecond) for each element.",
+        np.int64,
+    )
+
+
+class CombinedDatetimelikeAccessor(DatetimeAccessor, TimedeltaAccessor):
+    def __new__(cls, obj):
+        # CombinedDatetimelikeAccessor isn't really instantiated. Instead
+        # we need to choose which parent (datetime or timedelta) is
+        # appropriate. Since we're checking the dtypes anyway, we'll just
+        # do all the validation here.
+        if not _contains_datetime_like_objects(obj):
+            raise TypeError(
+                "'.dt' accessor only available for "
+                "DataArray with datetime64 timedelta64 dtype or "
+                "for arrays containing cftime datetime "
+                "objects."
+            )
+
+        if is_np_timedelta_like(obj.dtype):
+            return TimedeltaAccessor(obj)
+        else:
+            return DatetimeAccessor(obj)
diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index 8838e71e6ca..6a975b948eb 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -854,12 +854,10 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
         ----------
         pat : string or compiled regex
             String can be a character sequence or regular expression.
-
         repl : string or callable
             Replacement string or a callable. The callable is passed the regex
             match object and must return a replacement string to be used.
             See :func:`re.sub`.
-
         n : int, default -1 (all)
             Number of replacements to make from start
         case : boolean, default None
@@ -873,7 +871,7 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
             - If True, assumes the passed-in pattern is a regular expression.
             - If False, treats the pattern as a literal string
             - Cannot be set to False if `pat` is a compiled regex or `repl` is
-              a callable.
+              a callable.
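To make the new dispatch concrete: with `CombinedDatetimelikeAccessor` wired up as the `dt` property (see the `dataarray.py` hunk further down), the same accessor now serves both datetime64 and timedelta64 data. A small sketch of the expected behavior, not part of the patch itself:

```python
import pandas as pd
import xarray as xr

times = xr.DataArray(pd.date_range("2000-01-01", periods=4), dims="time")
deltas = times - times[0]

# __new__ picks the parent class from the dtype: datetime64 data gets a
# DatetimeAccessor, timedelta64 data gets the new TimedeltaAccessor.
print(times.dt.dayofyear.values)  # [1 2 3 4]
print(deltas.dt.days.values)      # [0 1 2 3]
```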
Returns ------- diff --git a/xarray/core/combine.py b/xarray/core/combine.py index b9db30a9f92..3f6e0e79351 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -88,7 +88,7 @@ def _infer_concat_order_from_coords(datasets): # with the same value have the same coord values throughout. if any(index.size == 0 for index in indexes): raise ValueError("Cannot handle size zero dimensions") - first_items = pd.Index([index.take([0]) for index in indexes]) + first_items = pd.Index([index[0] for index in indexes]) # Sort datasets along dim # We want rank but with identical elements given identical @@ -115,11 +115,12 @@ def _infer_concat_order_from_coords(datasets): return combined_ids, concat_dims -def _check_shape_tile_ids(combined_tile_ids): +def _check_dimension_depth_tile_ids(combined_tile_ids): + """ + Check all tuples are the same length, i.e. check that all lists are + nested to the same depth. + """ tile_ids = combined_tile_ids.keys() - - # Check all tuples are the same length - # i.e. check that all lists are nested to the same depth nesting_depths = [len(tile_id) for tile_id in tile_ids] if not nesting_depths: nesting_depths = [0] @@ -128,8 +129,13 @@ def _check_shape_tile_ids(combined_tile_ids): "The supplied objects do not form a hypercube because" " sub-lists do not have consistent depths" ) + # return these just to be reused in _check_shape_tile_ids + return tile_ids, nesting_depths - # Check all lists along one dimension are same length + +def _check_shape_tile_ids(combined_tile_ids): + """Check all lists along one dimension are same length.""" + tile_ids, nesting_depths = _check_dimension_depth_tile_ids(combined_tile_ids) for dim in range(nesting_depths[0]): indices_along_dim = [tile_id[dim] for tile_id in tile_ids] occurrences = Counter(indices_along_dim) @@ -536,7 +542,8 @@ def combine_by_coords( coords : {'minimal', 'different', 'all' or list of str}, optional As per the 'data_vars' kwarg, but for coordinate variables. fill_value : scalar, optional - Value to use for newly missing values + Value to use for newly missing values. If None, raises a ValueError if + the passed Datasets do not create a complete hypercube. join : {'outer', 'inner', 'left', 'right', 'exact'}, optional String indicating how to combine differing indexes (excluding concat_dim) in objects @@ -653,6 +660,15 @@ def combine_by_coords( temperature (y, x) float64 1.654 10.63 7.015 2.543 ... 12.46 2.22 15.96 precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 0.6125 0.4654 0.5953 + >>> xr.combine_by_coords([x1, x2, x3]) + + Dimensions: (x: 6, y: 4) + Coordinates: + * x (x) int64 10 20 30 40 50 60 + * y (y) int64 0 1 2 3 + Data variables: + temperature (y, x) float64 1.654 10.63 7.015 nan ... 12.46 2.22 15.96 + precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 
0.6125 0.4654 0.5953 """ # Group by data vars @@ -667,7 +683,13 @@ def combine_by_coords( list(datasets_with_same_vars) ) - _check_shape_tile_ids(combined_ids) + if fill_value is None: + # check that datasets form complete hypercube + _check_shape_tile_ids(combined_ids) + else: + # check only that all datasets have same dimension depth for these + # vars + _check_dimension_depth_tile_ids(combined_ids) # Concatenate along all of concat_dims one by one to create single ds concatenated = _combine_nd( diff --git a/xarray/core/common.py b/xarray/core/common.py index a74318b2f90..e908c69dd14 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1447,6 +1447,12 @@ def is_np_datetime_like(dtype: DTypeLike) -> bool: return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) +def is_np_timedelta_like(dtype: DTypeLike) -> bool: + """Check whether dtype is of the timedelta64 dtype. + """ + return np.issubdtype(dtype, np.timedelta64) + + def _contains_cftime_datetimes(array) -> bool: """Check if an array contains cftime.datetime objects """ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 643c1137d6c..d2c5c32bc00 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from . import duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align +from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like from .variable import Variable @@ -304,7 +305,7 @@ def _as_variables_or_variable(arg): def _unpack_dict_tuples( result_vars: Mapping[Hashable, Tuple[Variable, ...]], num_outputs: int ) -> Tuple[Dict[Hashable, Variable], ...]: - out = tuple({} for _ in range(num_outputs)) # type: ignore + out: Tuple[Dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs)) for name, values in result_vars.items(): for value, results_dict in zip(values, out): results_dict[name] = value @@ -547,6 +548,7 @@ def apply_variable_ufunc( output_dtypes=None, output_sizes=None, keep_attrs=False, + meta=None, ): """Apply a ndarray level function over Variable and/or ndarray objects. """ @@ -589,6 +591,7 @@ def func(*arrays): signature, output_dtypes, output_sizes, + meta, ) elif dask == "allowed": @@ -647,7 +650,14 @@ def func(*arrays): def _apply_blockwise( - func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None + func, + args, + input_dims, + output_dims, + signature, + output_dtypes, + output_sizes=None, + meta=None, ): import dask.array @@ -719,6 +729,7 @@ def _apply_blockwise( dtype=dtype, concatenate=True, new_axes=output_sizes, + meta=meta, ) @@ -760,6 +771,7 @@ def apply_ufunc( dask: str = "forbidden", output_dtypes: Sequence = None, output_sizes: Mapping[Any, int] = None, + meta: Any = None, ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. @@ -856,6 +868,9 @@ def apply_ufunc( Optional mapping from dimension names to sizes for outputs. Only used if dask='parallelized' and new dimensions (not found on inputs) appear on outputs. + meta : optional + Size-0 object representing the type of array wrapped by dask array. Passed on to + ``dask.array.blockwise``. 
Returns ------- @@ -989,6 +1004,11 @@ def earth_mover_distance(first_samples, func = functools.partial(func, **kwargs) if vectorize: + if meta is None: + # set meta=np.ndarray by default for numpy vectorized functions + # work around dask bug computing meta with vectorized functions: GH5642 + meta = np.ndarray + if signature.all_core_dims: func = np.vectorize( func, otypes=output_dtypes, signature=signature.to_gufunc_string() @@ -1005,6 +1025,7 @@ def earth_mover_distance(first_samples, dask=dask, output_dtypes=output_dtypes, output_sizes=output_sizes, + meta=meta, ) if any(isinstance(a, GroupBy) for a in args): @@ -1019,6 +1040,7 @@ def earth_mover_distance(first_samples, dataset_fill_value=dataset_fill_value, keep_attrs=keep_attrs, dask=dask, + meta=meta, ) return apply_groupby_func(this_apply, *args) elif any(is_dict_like(a) for a in args): @@ -1175,6 +1197,11 @@ def dot(*arrays, dims=None, **kwargs): subscripts = ",".join(subscripts_list) subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]]) + join = OPTIONS["arithmetic_join"] + # using "inner" emulates `(a * b).sum()` for all joins (except "exact") + if join != "exact": + join = "inner" + # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for apply_ufunc to work. func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) @@ -1183,6 +1210,7 @@ def dot(*arrays, dims=None, **kwargs): *arrays, input_core_dims=input_core_dims, output_core_dims=output_core_dims, + join=join, dask="allowed", ) return result.transpose(*[d for d in all_dims if d in result.dims]) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 5ccbfa3f2b4..96b4be15d1b 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -93,12 +93,14 @@ def concat( those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - indexers, mode, concat_over : deprecated - Returns ------- concatenated : type of objs + Notes + ----- + Each concatenated Variable preserves corresponding ``attrs`` from the first element of ``objs``. + See also -------- merge @@ -192,7 +194,23 @@ def process_subset_opt(opt, subset): for k in getattr(datasets[0], subset): if k not in concat_over: equals[k] = None - variables = [ds.variables[k] for ds in datasets] + + variables = [] + for ds in datasets: + if k in ds.variables: + variables.append(ds.variables[k]) + + if len(variables) == 1: + # coords="different" doesn't make sense when only one object + # contains a particular variable. + break + elif len(variables) != len(datasets) and opt == "different": + raise ValueError( + f"{k!r} not present in all datasets and coords='different'. " + f"Either add {k!r} to datasets where it is missing or " + "specify coords='minimal'." + ) + # first check without comparing values i.e. 
no computes for var in variables[1:]: equals[k] = getattr(variables[0], compat)( diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index c3dbdd27098..05f750a1355 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,8 +1,16 @@ from distutils.version import LooseVersion +from typing import Iterable -import dask.array as da import numpy as np -from dask import __version__ as dask_version + +from .pycompat import dask_array_type + +try: + import dask.array as da + from dask import __version__ as dask_version +except ImportError: + dask_version = "0.0.0" + da = None if LooseVersion(dask_version) >= LooseVersion("2.0.0"): meta_from_array = da.utils.meta_from_array @@ -30,7 +38,7 @@ def meta_from_array(x, ndim=None, dtype=None): """ # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr) # implement a _meta attribute that are incompatible with Dask Array._meta - if hasattr(x, "_meta") and isinstance(x, da.Array): + if hasattr(x, "_meta") and isinstance(x, dask_array_type): x = x._meta if dtype is None and x is None: @@ -89,3 +97,76 @@ def meta_from_array(x, ndim=None, dtype=None): meta = meta.astype(dtype) return meta + + +if LooseVersion(dask_version) >= LooseVersion("2.8.1"): + median = da.median +else: + # Copied from dask v2.8.1 + # Used under the terms of Dask's license, see licenses/DASK_LICENSE. + def median(a, axis=None, keepdims=False): + """ + This works by automatically chunking the reduced axes to a single chunk + and then calling ``numpy.median`` function across the remaining dimensions + """ + + if axis is None: + raise NotImplementedError( + "The da.median function only works along an axis. " + "The full algorithm is difficult to do in parallel" + ) + + if not isinstance(axis, Iterable): + axis = (axis,) + + axis = [ax + a.ndim if ax < 0 else ax for ax in axis] + + a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) + + result = a.map_blocks( + np.median, + axis=axis, + keepdims=keepdims, + drop_axis=axis if not keepdims else None, + chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] + if keepdims + else None, + ) + + return result + + +if LooseVersion(dask_version) > LooseVersion("2.9.0"): + nanmedian = da.nanmedian +else: + + def nanmedian(a, axis=None, keepdims=False): + """ + This works by automatically chunking the reduced axes to a single chunk + and then calling ``numpy.nanmedian`` function across the remaining dimensions + """ + + if axis is None: + raise NotImplementedError( + "The da.nanmedian function only works along an axis. 
" + "The full algorithm is difficult to do in parallel" + ) + + if not isinstance(axis, Iterable): + axis = (axis,) + + axis = [ax + a.ndim if ax < 0 else ax for ax in axis] + + a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) + + result = a.map_blocks( + np.nanmedian, + axis=axis, + keepdims=keepdims, + drop_axis=axis if not keepdims else None, + chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] + if keepdims + else None, + ) + + return result diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 20de0cffbc2..062cc6342df 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,3 +1,4 @@ +import datetime import functools import warnings from numbers import Number @@ -33,7 +34,7 @@ rolling, utils, ) -from .accessor_dt import DatetimeAccessor +from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor from .alignment import ( _broadcast_helper, @@ -235,19 +236,6 @@ class DataArray(AbstractArray, DataWithCoords): Getting items from or doing mathematical operations with a DataArray always returns another DataArray. - - Attributes - ---------- - dims : tuple - Dimension names associated with this array. - values : numpy.ndarray - Access or modify DataArray values as a numpy array. - coords : dict-like - Dictionary of DataArray objects that label values along each dimension. - name : str or None - Name of this array. - attrs : dict - Dictionary for holding arbitrary metadata. """ _cache: Dict[str, Any] @@ -271,7 +259,7 @@ class DataArray(AbstractArray, DataWithCoords): _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample - dt = property(DatetimeAccessor) + dt = property(CombinedDatetimelikeAccessor) def __init__( self, @@ -280,8 +268,6 @@ def __init__( dims: Union[Hashable, Sequence[Hashable], None] = None, name: Hashable = None, attrs: Mapping = None, - # deprecated parameters - encoding=None, # internal parameters indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, @@ -326,20 +312,10 @@ def __init__( Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. """ - if encoding is not None: - warnings.warn( - "The `encoding` argument to `DataArray` is deprecated, and . " - "will be removed in 0.15. " - "Instead, specify the encoding when writing to disk or " - "set the `encoding` attribute directly.", - FutureWarning, - stacklevel=2, - ) if fastpath: variable = data assert dims is None assert attrs is None - assert encoding is None else: # try to fill in arguments from data if they weren't supplied if coords is None: @@ -361,13 +337,11 @@ def __init__( name = getattr(data, "name", None) if attrs is None and not isinstance(data, PANDAS_TYPES): attrs = getattr(data, "attrs", None) - if encoding is None: - encoding = getattr(data, "encoding", None) data = _check_data_shape(data, coords, dims) data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) - variable = Variable(dims, data, attrs, encoding, fastpath=True) + variable = Variable(dims, data, attrs, fastpath=True) indexes = dict( _extract_indexes_from_coords(coords) ) # needed for to_dataset @@ -1128,7 +1102,7 @@ def thin( **indexers_kwargs: Any, ) -> "DataArray": """Return a new DataArray whose data is given by each `n` value - along the specified dimension(s). Default `n` = 5 + along the specified dimension(s). 
         See Also
         --------
@@ -1302,7 +1276,7 @@ def reindex(
             satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
         fill_value : scalar, optional
             Value to use for newly missing values
-        **indexers_kwarg : {dim: indexer, ...}, optional
+        **indexers_kwargs : {dim: indexer, ...}, optional
             The keyword arguments form of ``indexers``.
             One of indexers or indexers_kwargs must be provided.
@@ -1351,7 +1325,7 @@ def interp(
             values.
         kwargs: dictionary
             Additional keyword passed to scipy's interpolator.
-        ``**coords_kwarg`` : {dim: coordinate, ...}, optional
+        ``**coords_kwargs`` : {dim: coordinate, ...}, optional
             The keyword arguments form of ``coords``.
             One of coords or coords_kwargs must be provided.
@@ -1493,8 +1467,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
         ----------
         dims_dict : dict-like
             Dictionary whose keys are current dimension names and whose values
-            are new names. Each value must already be a coordinate on this
-            array.
+            are new names.
 
         Returns
         -------
@@ -1517,6 +1490,13 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
         Coordinates:
             x        (y) <U1 'a' 'b'
           * y        (y) int64 0 1
+        >>> arr.swap_dims({"x": "z"})
+        <xarray.DataArray (z: 2)>
+        array([0, 1])
+        Coordinates:
+            x        (z) <U1 'a' 'b'
+            y        (z) int64 0 1
+        Dimensions without coordinates: z
 
         See Also
         --------
@@ ... @@ def interpolate_na(
         self,
         dim: Hashable = None,
         method: str = "linear",
         limit: int = None,
         use_coordinate: Union[bool, str] = True,
-        max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64] = None,
+        max_gap: Union[
+            int, float, str, pd.Timedelta, np.timedelta64, datetime.timedelta
+        ] = None,
         **kwargs: Any,
     ) -> "DataArray":
         """Fill in NaNs by interpolating according to different methods.
@@ -2094,7 +2076,7 @@ def interpolate_na(
             or None for no limit. This filling is done regardless of the size of
             the gap in the data. To only interpolate over gaps less than a given length,
             see ``max_gap``.
-        max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, default None.
+        max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default None.
             Maximum size of gap, a continuous sequence of NaNs, that will be filled.
             Use None for no limit. When interpolating along a datetime64 dimension
             and ``use_coordinate=True``, ``max_gap`` can be one of the following:
@@ -2102,6 +2084,7 @@ def interpolate_na(
             - a string that is valid input for pandas.to_timedelta
             - a :py:class:`numpy.timedelta64` object
             - a :py:class:`pandas.Timedelta` object
+            - a :py:class:`datetime.timedelta` object
             Otherwise, ``max_gap`` must be an int or a float. Use of ``max_gap`` with unlabeled
             dimensions has not been implemented yet. Gap length is defined as the difference
@@ -2375,7 +2358,7 @@ def to_dict(self, data: bool = True) -> dict:
         naming conventions. Converts all variables and attributes to native
         Python objects.
 
-        Useful for coverting to json. To avoid datetime incompatibility
+        Useful for converting to json. To avoid datetime incompatibility
         use decode_times=False kwarg in xarrray.open_dataset.
 
         Parameters
@@ -2753,7 +2736,7 @@ def shift(
             Value to use for newly missing values
         **shifts_kwargs:
             The keyword arguments form of ``shifts``.
-            One of shifts or shifts_kwarg must be provided.
+            One of shifts or shifts_kwargs must be provided.
 
         Returns
         -------
@@ -2804,7 +2787,7 @@ def roll(
             deprecated and will change to False in a future version.
             Explicitly pass roll_coords to silence the warning.
         **shifts_kwargs : The keyword arguments form of ``shifts``.
-            One of shifts or shifts_kwarg must be provided.
+            One of shifts or shifts_kwargs must be provided.
 
         Returns
         -------
@@ -2990,7 +2973,7 @@ def quantile(
 
         See Also
         --------
-        numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile
+        numpy.nanquantile, pandas.Series.quantile, Dataset.quantile
 
         Examples
         --------
@@ -3000,8 +2983,6 @@ def quantile(
         ...     coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]},
         ...     dims=("x", "y"),
         ... )
-
-        Single quantile
         >>> da.quantile(0)  # or da.quantile(0, dim=...)
array(0.7) @@ -3013,8 +2994,6 @@ def quantile( Coordinates: * y (y) float64 1.0 1.5 2.0 2.5 quantile float64 0.0 - - Multiple quantiles >>> da.quantile([0, 0.5, 1]) array([0.7, 3.4, 9.4]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5926fd4ff36..07bea6dac19 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,4 +1,5 @@ import copy +import datetime import functools import sys import warnings @@ -64,6 +65,7 @@ default_indexes, isel_variable_and_index, propagate_indexes, + remove_unused_levels_categories, roll_index, ) from .indexing import is_fancy_indexer @@ -85,11 +87,16 @@ either_dict_or_kwargs, hashable, is_dict_like, - is_list_like, is_scalar, maybe_wrap_array, ) -from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from .variable import ( + IndexVariable, + Variable, + as_variable, + assert_unique_multiindex_level_names, + broadcast_variables, +) if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -458,7 +465,6 @@ def __init__( data_vars: Mapping[Hashable, Any] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, - compat=None, ): """To load data from a file or file-like object, use the `open_dataset` function. @@ -508,18 +514,7 @@ def __init__( attrs : dict-like, optional Global attributes to save on this dataset. - compat : deprecated """ - if compat is not None: - warnings.warn( - "The `compat` argument to Dataset is deprecated and will be " - "removed in 0.15." - "Instead, use `merge` to control how variables are combined", - FutureWarning, - stacklevel=2, - ) - else: - compat = "broadcast_equals" # TODO(shoyer): expose indexes as a public argument in __init__ @@ -539,7 +534,7 @@ def __init__( coords = coords.variables variables, coord_names, dims, indexes = merge_data_and_coords( - data_vars, coords, compat=compat + data_vars, coords, compat="broadcast_equals" ) self._attrs = dict(attrs) if attrs is not None else None @@ -904,11 +899,11 @@ def _replace( if dims is not None: self._dims = dims if attrs is not _default: - self._attrs = attrs # type: ignore # FIXME need mypy 0.750 + self._attrs = attrs if indexes is not _default: - self._indexes = indexes # type: ignore # FIXME need mypy 0.750 + self._indexes = indexes if encoding is not _default: - self._encoding = encoding # type: ignore # FIXME need mypy 0.750 + self._encoding = encoding obj = self else: if variables is None: @@ -1748,7 +1743,10 @@ def maybe_chunk(name, var, chunks): if not chunks: chunks = None if var.ndim > 0: - token2 = tokenize(name, token if token else var._data) + # when rechunking by different amounts, make sure dask names change + # by provinding chunks as an input to tokenize. + # subtle bugs result otherwise. see GH3350 + token2 = tokenize(name, token if token else var._data, chunks) name2 = f"{name_prefix}{name}-{token2}" return var.chunk(chunks, name=name2, lock=lock) else: @@ -1887,7 +1885,7 @@ def isel( drop : bool, optional If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. - **indexers_kwarg : {dim: indexer, ...}, optional + **indexers_kwargs : {dim: indexer, ...}, optional The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. @@ -2033,7 +2031,7 @@ def sel( drop : bool, optional If ``drop=True``, drop coordinates variables in `indexers` instead of making them scalar. 
-        **indexers_kwarg : {dim: indexer, ...}, optional
+        **indexers_kwargs : {dim: indexer, ...}, optional
             The keyword arguments form of ``indexers``.
             One of indexers or indexers_kwargs must be provided.
@@ -2168,7 +2166,7 @@ def thin(
         Parameters
         ----------
-        indexers : dict or int, default: 5
+        indexers : dict or int
             A dict with keys matching dimensions and integer values `n`
             or a single integer `n` applied over all dimensions.
             One of indexers or indexers_kwargs must be provided.
@@ -2332,7 +2330,7 @@ def reindex(
         fill_value : scalar, optional
             Value to use for newly missing values
         sparse: use sparse-array. By default, False
-        **indexers_kwarg : {dim: indexer, ...}, optional
+        **indexers_kwargs : {dim: indexer, ...}, optional
             Keyword arguments in the same form as ``indexers``.
             One of indexers or indexers_kwargs must be provided.
@@ -2547,7 +2545,7 @@ def interp(
             values.
         kwargs: dictionary, optional
             Additional keyword passed to scipy's interpolator.
-        **coords_kwarg : {dim: coordinate, ...}, optional
+        **coords_kwargs : {dim: coordinate, ...}, optional
             The keyword arguments form of ``coords``.
             One of coords or coords_kwargs must be provided.
@@ -2780,6 +2778,7 @@ def rename(
         variables, coord_names, dims, indexes = self._rename_all(
             name_dict=name_dict, dims_dict=name_dict
         )
+        assert_unique_multiindex_level_names(variables)
         return self._replace(variables, coord_names, dims=dims, indexes=indexes)
 
     def rename_dims(
@@ -2791,7 +2790,8 @@ def rename_dims(
         ----------
         dims_dict : dict-like, optional
             Dictionary whose keys are current dimension names and
-            whose values are the desired names.
+            whose values are the desired names. The desired names must
+            not be the name of an existing dimension or Variable in the Dataset.
         **dims, optional
             Keyword form of ``dims_dict``.
             One of dims_dict or dims must be provided.
@@ -2809,12 +2809,17 @@ def rename_dims(
         DataArray.rename
         """
         dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims")
-        for k in dims_dict:
+        for k, v in dims_dict.items():
             if k not in self.dims:
                 raise ValueError(
                     "cannot rename %r because it is not a "
                     "dimension in this dataset" % k
                 )
+            if v in self.dims or v in self:
+                raise ValueError(
+                    f"Cannot rename {k} to {v} because {v} already exists. "
+                    "Try using swap_dims instead."
+                )
 
         variables, coord_names, sizes, indexes = self._rename_all(
             name_dict={}, dims_dict=dims_dict
@@ -2868,8 +2873,7 @@ def swap_dims(
         ----------
         dims_dict : dict-like
             Dictionary whose keys are current dimension names and whose values
-            are new names. Each value must already be a variable in the
-            dataset.
+            are new names.
 
         Returns
         -------
@@ -2898,6 +2902,16 @@ def swap_dims(
         Data variables:
             a        (y) int64 5 7
             b        (y) float64 0.1 2.4
+        >>> ds.swap_dims({"x": "z"})
+        <xarray.Dataset>
+        Dimensions:  (z: 2)
+        Coordinates:
+            x        (z) <U1 'a' 'b'
+            y        (z) int64 0 1
+        Dimensions without coordinates: z
+        Data variables:
+            a        (z) int64 5 7
+            b        (z) float64 0.1 2.4
@@ ... @@ def _unstack_once(self, dim: Hashable) -> "Dataset":
         index = self.get_index(dim)
-        index = index.remove_unused_levels()
+        index = remove_unused_levels_categories(index)
 
         full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
 
         # take a shortcut in case the MultiIndex was not modified.
@@ -3525,7 +3539,7 @@ def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset":
 
     def merge(
         self,
-        other: "CoercibleMapping",
+        other: Union["CoercibleMapping", "DataArray"],
         inplace: bool = None,
         overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(),
         compat: str = "no_conflicts",
@@ -3582,6 +3596,7 @@ def merge(
         If any variables conflict (see ``compat``).
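A quick sketch of the new `rename_dims` guard added above; the dataset is illustrative:

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [5, 7])})

# Renaming a dimension to a name that already exists in the dataset now
# fails fast and points the user at swap_dims instead.
try:
    ds.rename_dims({"x": "a"})
except ValueError as err:
    print(err)  # Cannot rename x to a because a already exists. ...
```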
""" _check_inplace(inplace) + other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( self, other, @@ -3664,7 +3679,7 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): raise ValueError("cannot specify dim and dict-like arguments.") labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") - if dim is None and (is_list_like(labels) or is_scalar(labels)): + if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)): warnings.warn( "dropping variables using `drop` will be deprecated; using drop_vars is encouraged.", PendingDeprecationWarning, @@ -3981,7 +3996,9 @@ def interpolate_na( method: str = "linear", limit: int = None, use_coordinate: Union[bool, Hashable] = True, - max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64] = None, + max_gap: Union[ + int, float, str, pd.Timedelta, np.timedelta64, datetime.timedelta + ] = None, **kwargs: Any, ) -> "Dataset": """Fill in NaNs by interpolating according to different methods. @@ -4014,7 +4031,7 @@ def interpolate_na( or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, default None. + max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default None. Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -4022,6 +4039,7 @@ def interpolate_na( - a string that is valid input for pandas.to_timedelta - a :py:class:`numpy.timedelta64` object - a :py:class:`pandas.Timedelta` object + - a :py:class:`datetime.timedelta` object Otherwise, ``max_gap`` must be an int or a float. Use of ``max_gap`` with unlabeled dimensions has not been implemented yet. Gap length is defined as the difference @@ -4127,7 +4145,7 @@ def combine_first(self, other: "Dataset") -> "Dataset": Returns ------- - DataArray + Dataset """ out = ops.fillna(self, other, join="outer", dataset_join="outer") return out @@ -4447,22 +4465,19 @@ def to_dataframe(self): return self._to_dataframe(self.dims) def _set_sparse_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] + self, dataframe: pd.DataFrame, dims: tuple ) -> None: from sparse import COO idx = dataframe.index if isinstance(idx, pd.MultiIndex): - try: - codes = idx.codes - except AttributeError: - # deprecated since pandas 0.24 - codes = idx.labels - coords = np.stack([np.asarray(code) for code in codes], axis=0) + coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) is_sorted = idx.is_lexsorted + shape = tuple(lev.size for lev in idx.levels) else: coords = np.arange(idx.size).reshape(1, -1) is_sorted = True + shape = (idx.size,) for name, series in dataframe.items(): # Cast to a NumPy array first, in case the Series is a pandas @@ -4487,14 +4502,16 @@ def _set_sparse_data_from_dataframe( self[name] = (dims, data) def _set_numpy_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] 
+ self, dataframe: pd.DataFrame, dims: tuple ) -> None: idx = dataframe.index if isinstance(idx, pd.MultiIndex): # expand the DataFrame to include the product of all levels full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) dataframe = dataframe.reindex(full_idx) - + shape = tuple(lev.size for lev in idx.levels) + else: + shape = (idx.size,) for name, series in dataframe.items(): data = np.asarray(series).reshape(shape) self[name] = (dims, data) @@ -4535,7 +4552,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas if not dataframe.columns.is_unique: raise ValueError("cannot convert DataFrame with non-unique columns") - idx = dataframe.index + idx = remove_unused_levels_categories(dataframe.index) + dataframe = dataframe.set_index(idx) obj = cls() if isinstance(idx, pd.MultiIndex): @@ -4545,17 +4563,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas ) for dim, lev in zip(dims, idx.levels): obj[dim] = (dim, lev) - shape = tuple(lev.size for lev in idx.levels) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) obj[index_name] = (dims, idx) - shape = (idx.size,) if sparse: - obj._set_sparse_data_from_dataframe(dataframe, dims, shape) + obj._set_sparse_data_from_dataframe(dataframe, dims) else: - obj._set_numpy_data_from_dataframe(dataframe, dims, shape) + obj._set_numpy_data_from_dataframe(dataframe, dims) return obj def to_dask_dataframe(self, dim_order=None, set_index=False): @@ -4641,7 +4657,7 @@ def to_dict(self, data=True): conventions. Converts all variables and attributes to native Python objects - Useful for coverting to json. To avoid datetime incompatibility + Useful for converting to json. To avoid datetime incompatibility use decode_times=False kwarg in xarrray.open_dataset. Parameters @@ -4938,7 +4954,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Value to use for newly missing values **shifts_kwargs: The keyword arguments form of ``shifts``. - One of shifts or shifts_kwarg must be provided. + One of shifts or shifts_kwargs must be provided. Returns ------- @@ -5158,7 +5174,7 @@ def quantile( See Also -------- - numpy.nanpercentile, pandas.Series.quantile, DataArray.quantile + numpy.nanquantile, pandas.Series.quantile, DataArray.quantile Examples -------- @@ -5167,8 +5183,6 @@ def quantile( ... {"a": (("x", "y"), [[0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]])}, ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, ... ) - - Single quantile >>> ds.quantile(0) # or ds.quantile(0, dim=...) Dimensions: () @@ -5184,8 +5198,6 @@ def quantile( quantile float64 0.0 Data variables: a (y) float64 0.7 4.2 2.6 1.5 - - Multiple quantiles >>> ds.quantile([0, 0.5, 1]) Dimensions: (quantile: 3) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cf616acb485..bc2db93a0a8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd -from . import dask_array_ops, dtypes, npcompat, nputils +from . 
import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast from .pycompat import dask_array_type @@ -37,7 +37,7 @@ def f(*args, **kwargs): dispatch_args = args[0] else: dispatch_args = args[array_args] - if any(isinstance(a, dask_array.Array) for a in dispatch_args): + if any(isinstance(a, dask_array_type) for a in dispatch_args): try: wrapped = getattr(dask_module, name) except AttributeError as e: @@ -121,6 +121,7 @@ def notnull(data): isin = _dask_or_eager_func("isin", array_args=slice(2)) take = _dask_or_eager_func("take") broadcast_to = _dask_or_eager_func("broadcast_to") +pad = _dask_or_eager_func("pad") _concatenate = _dask_or_eager_func("concatenate", list_of_args=True) _stack = _dask_or_eager_func("stack", list_of_args=True) @@ -189,8 +190,8 @@ def lazy_array_equiv(arr1, arr2): return False if ( dask_array - and isinstance(arr1, dask_array.Array) - and isinstance(arr2, dask_array.Array) + and isinstance(arr1, dask_array_type) + and isinstance(arr2, dask_array_type) ): # GH3068 if arr1.name == arr2.name: @@ -261,7 +262,10 @@ def where_method(data, cond, other=dtypes.NA): def fillna(data, other): - return where(isnull(data), other, data) + # we need to pass data first so pint has a chance of returning the + # correct unit + # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed + return where(notnull(data), data, other) def concatenate(arrays, axis=0): @@ -284,7 +288,7 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, coerce_strings=False): +def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False): from . import nanops def f(values, axis=None, skipna=None, **kwargs): @@ -301,7 +305,7 @@ def f(values, axis=None, skipna=None, **kwargs): nanname = "nan" + name func = getattr(nanops, nanname) else: - func = _dask_or_eager_func(name) + func = _dask_or_eager_func(name, dask_module=dask_module) try: return func(values, axis=axis, **kwargs) @@ -337,7 +341,7 @@ def f(values, axis=None, skipna=None, **kwargs): std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method("median") +median = _create_nan_agg_method("median", dask_module=dask_array_compat) median.numeric_only = True prod = _create_nan_agg_method("prod") prod.numeric_only = True @@ -372,44 +376,100 @@ def _datetime_nanmin(array): def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): - """Convert an array containing datetime-like data to an array of floats. + """Convert an array containing datetime-like data to numerical values. + + Convert the datetime array to a timedelta relative to an offset. Parameters ---------- - da : np.array - Input data - offset: Scalar with the same type of array or None - If None, subtract minimum values to reduce round off error - datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', - 'us', 'ns', 'ps', 'fs', 'as'} - dtype: target dtype + da : array-like + Input data + offset: None, datetime or cftime.datetime + Datetime offset. If None, this is set by default to the array's minimum + value to reduce round off errors. + datetime_unit: {None, Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} + If not None, convert output to a given datetime unit. Note that some + conversions are not allowed due to non-linear relationships between units. + dtype: dtype + Output dtype. Returns ------- array + Numerical representation of datetime object relative to an offset. 
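For reference, a sketch of the documented behavior, calling the helper directly from its internal module (an implementation detail, not public API):

```python
import numpy as np
from xarray.core.duck_array_ops import datetime_to_numeric

times = np.array(
    ["2000-01-01", "2000-01-02", "2000-01-04"], dtype="datetime64[ns]"
)

# offset defaults to the array minimum; datetime_unit="D" scales the
# resulting timedeltas to days.
print(datetime_to_numeric(times, datetime_unit="D"))  # [0. 1. 3.]
```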
+ + Notes + ----- + Some datetime unit conversions won't work, for example from days to years, even + though some calendars would allow for them (e.g. no_leap). This is because there + is no `cftime.timedelta` object. """ # TODO: make this function dask-compatible? + # Set offset to minimum if not given if offset is None: if array.dtype.kind in "Mm": offset = _datetime_nanmin(array) else: offset = min(array) + + # Compute timedelta object. + # For np.datetime64, this can silently yield garbage due to overflow. + # One option is to enforce 1970-01-01 as the universal offset. array = array - offset - if not hasattr(array, "dtype"): # scalar is converted to 0d-array + # Scalar is converted to 0d-array + if not hasattr(array, "dtype"): array = np.array(array) + # Convert timedelta objects to float by first converting to microseconds. if array.dtype.kind in "O": - # possibly convert object array containing datetime.timedelta - array = np.asarray(pd.Series(array.ravel())).reshape(array.shape) + return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype) - if datetime_unit: - array = array / np.timedelta64(1, datetime_unit) + # Convert np.NaT to np.nan + elif array.dtype.kind in "mM": - # convert np.NaT to np.nan - if array.dtype.kind in "mM": + # Convert to specified timedelta units. + if datetime_unit: + array = array / np.timedelta64(1, datetime_unit) return np.where(isnull(array), np.nan, array.astype(dtype)) - return array.astype(dtype) + + +def timedelta_to_numeric(value, datetime_unit="ns", dtype=float): + """Convert a timedelta-like object to numerical values. + + Parameters + ---------- + value : datetime.timedelta, numpy.timedelta64, pandas.Timedelta, str + Time delta representation. + datetime_unit : {Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} + The time units of the output values. Note that some conversions are not allowed due to + non-linear relationships between units. + dtype : type + The output data type. + + """ + import datetime as dt + + if isinstance(value, dt.timedelta): + out = py_timedelta_to_float(value, datetime_unit) + elif isinstance(value, np.timedelta64): + out = np_timedelta64_to_float(value, datetime_unit) + elif isinstance(value, pd.Timedelta): + out = pd_timedelta_to_float(value, datetime_unit) + elif isinstance(value, str): + try: + a = pd.to_timedelta(value) + except ValueError: + raise ValueError( + f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta" + ) + return py_timedelta_to_float(a, datetime_unit) + else: + raise TypeError( + f"Expected value of type str, pandas.Timedelta, datetime.timedelta " + f"or numpy.timedelta64, but received {type(value).__name__}" + ) + return out.astype(dtype) def _to_pytimedelta(array, unit="us"): @@ -417,6 +477,40 @@ def _to_pytimedelta(array, unit="us"): return index.to_pytimedelta().reshape(array.shape) +def np_timedelta64_to_float(array, datetime_unit): + """Convert numpy.timedelta64 to float. + + Notes + ----- + The array is first converted to microseconds, which is less likely to + cause overflow errors. + """ + array = array.astype("timedelta64[ns]").astype(np.float64) + conversion_factor = np.timedelta64(1, "ns") / np.timedelta64(1, datetime_unit) + return conversion_factor * array + + +def pd_timedelta_to_float(value, datetime_unit): + """Convert pandas.Timedelta to float. + + Notes + ----- + Built on the assumption that pandas timedelta values are in nanoseconds, + which is also the numpy default resolution. 
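Similarly, a sketch of the new `timedelta_to_numeric` entry point defined above (again imported from the internal module purely for illustration):

```python
import datetime

import numpy as np
import pandas as pd
from xarray.core.duck_array_ops import timedelta_to_numeric

# Each supported input type is routed to its converter and normalized
# through an overflow-resistant intermediate unit.
print(timedelta_to_numeric("1D", datetime_unit="h"))                     # 24.0
print(timedelta_to_numeric(np.timedelta64(90, "m"), datetime_unit="h"))  # 1.5
print(timedelta_to_numeric(datetime.timedelta(minutes=30), "h"))         # 0.5
print(timedelta_to_numeric(pd.Timedelta("2H"), datetime_unit="h"))       # 2.0
```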
+ """ + value = value.to_timedelta64() + return np_timedelta64_to_float(value, datetime_unit) + + +def py_timedelta_to_float(array, datetime_unit): + """Convert a timedelta object to a float, possibly at a loss of resolution. + """ + array = np.asarray(array) + array = np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6 + conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit) + return conversion_factor * array + + def mean(array, axis=None, skipna=None, **kwargs): """inhouse mean that can handle np.datetime64 or cftime.datetime dtypes""" diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 520fa9b9f1b..89246ff228d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -500,6 +500,13 @@ def diff_dim_summary(a, b): def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): + def is_array_like(value): + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + ) + def extra_items_repr(extra_keys, mapping, ab_side): extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] if extra_repr: @@ -522,7 +529,11 @@ def extra_items_repr(extra_keys, mapping, ab_side): is_variable = True except AttributeError: # compare attribute value - compatible = a_mapping[k] == b_mapping[k] + if is_array_like(a_mapping[k]) or is_array_like(b_mapping[k]): + compatible = array_equiv(a_mapping[k], b_mapping[k]) + else: + compatible = a_mapping[k] == b_mapping[k] + is_variable = False if not compatible: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index cb8f6538820..f2a9ebac6eb 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -595,7 +595,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): See Also -------- - numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile Examples @@ -607,8 +607,6 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): ... dims=("y", "y"), ... ) >>> ds = xr.Dataset({"a": da}) - - Single quantile >>> da.groupby("x").quantile(0) array([[0.7, 4.2, 0.7, 1.5], @@ -625,15 +623,12 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): * y (y) int64 1 2 Data variables: a (y) float64 0.7 0.7 - - Multiple quantiles >>> da.groupby("x").quantile([0, 0.5, 1]) array([[[0.7 , 1. , 1.3 ], [4.2 , 6.3 , 8.4 ], [0.7 , 5.05, 9.4 ], [1.5 , 4.2 , 6.9 ]], - [[6.5 , 6.5 , 6.5 ], [7.3 , 7.3 , 7.3 ], [2.6 , 2.6 , 2.6 ], diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8337a0f082a..06bf08cefd2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -9,6 +9,26 @@ from .variable import Variable +def remove_unused_levels_categories(index): + """ + Remove unused levels from MultiIndex and unused categories from CategoricalIndex + """ + if isinstance(index, pd.MultiIndex): + index = index.remove_unused_levels() + # if it contains CategoricalIndex, we need to remove unused categories + # manually. 
See https://github.com/pandas-dev/pandas/issues/30846 + if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + levels = [] + for i, level in enumerate(index.levels): + if isinstance(level, pd.CategoricalIndex): + level = level[index.codes[i]].remove_unused_categories() + levels.append(level) + index = pd.MultiIndex.from_arrays(levels, names=index.names) + elif isinstance(index, pd.CategoricalIndex): + index = index.remove_unused_categories() + return index + + class Indexes(collections.abc.Mapping): """Immutable proxy for Dataset or DataArrary indexes.""" diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8e851b39c3e..ab049a0a4b4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import suppress from datetime import timedelta -from typing import Any, Callable, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, Sequence, Tuple, Union import numpy as np import pandas as pd @@ -175,6 +175,16 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No if label.ndim == 0: if isinstance(index, pd.MultiIndex): indexer, new_index = index.get_loc_level(label.item(), level=0) + elif isinstance(index, pd.CategoricalIndex): + if method is not None: + raise ValueError( + "'method' is not a valid kwarg when indexing using a CategoricalIndex." + ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + ) + indexer = index.get_loc(label.item()) else: indexer = index.get_loc( label.item(), method=method, tolerance=tolerance @@ -1304,6 +1314,24 @@ def __init__(self, array): self.array = array def __getitem__(self, key): + + if not isinstance(key, VectorizedIndexer): + # if possible, short-circuit when keys are effectively slice(None) + # This preserves dask name and passes lazy array equivalence checks + # (see duck_array_ops.lazy_array_equiv) + rewritten_indexer = False + new_indexer = [] + for idim, k in enumerate(key.tuple): + if isinstance(k, Iterable) and duck_array_ops.array_equiv( + k, np.arange(self.array.shape[idim]) + ): + new_indexer.append(slice(None)) + rewritten_indexer = True + else: + new_indexer.append(k) + if rewritten_indexer: + key = type(key)(tuple(new_indexer)) + if isinstance(key, BasicIndexer): return self.array[key.tuple] elif isinstance(key, VectorizedIndexer): diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 117fcaf8f81..40f010b3514 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,3 +1,4 @@ +import datetime as dt import warnings from functools import partial from numbers import Number @@ -9,7 +10,7 @@ from . import utils from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc -from .duck_array_ops import dask_array_type +from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables @@ -207,52 +208,81 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] = True): - """get index to use for x values in interpolation. + """Return index to use for x values in interpolation or curve fitting. - If use_coordinate is True, the coordinate that shares the name of the - dimension along which interpolation is being performed will be used as the - x values. 
+ Parameters + ---------- + arr : DataArray + Array to interpolate or fit to a curve. + dim : str + Name of dimension along which to fit. + use_coordinate : str or bool + If use_coordinate is True, the coordinate that shares the name of the + dimension along which interpolation is being performed will be used as the + x values. If False, the x values are set as an equally spaced sequence. + + Returns + ------- + Variable + Numerical values for the x-coordinates. - If use_coordinate is False, the x values are set as an equally spaced - sequence. + Notes + ----- + If indexing is along the time dimension, datetime coordinates are converted + to time deltas with respect to 1970-01-01. """ - if use_coordinate: - if use_coordinate is True: - index = arr.get_index(dim) - else: - index = arr.coords[use_coordinate] - if index.ndim != 1: - raise ValueError( - f"Coordinates used for interpolation must be 1D, " - f"{use_coordinate} is {index.ndim}D." - ) - index = index.to_index() - - # TODO: index.name is None for multiindexes - # set name for nice error messages below - if isinstance(index, pd.MultiIndex): - index.name = dim - - if not index.is_monotonic: - raise ValueError(f"Index {index.name!r} must be monotonically increasing") - - if not index.is_unique: - raise ValueError(f"Index {index.name!r} has duplicate values") - - # raise if index cannot be cast to a float (e.g. MultiIndex) - try: - index = index.values.astype(np.float64) - except (TypeError, ValueError): - # pandas raises a TypeError - # xarray/numpy raise a ValueError - raise TypeError( - f"Index {index.name!r} must be castable to float64 to support " - f"interpolation, got {type(index).__name__}." - ) - else: + # Question: If use_coordinate is a string, what role does `dim` play? + from xarray.coding.cftimeindex import CFTimeIndex + + if use_coordinate is False: axis = arr.get_axis_num(dim) - index = np.arange(arr.shape[axis], dtype=np.float64) + return np.arange(arr.shape[axis], dtype=np.float64) + + if use_coordinate is True: + index = arr.get_index(dim) + + else: # string + index = arr.coords[use_coordinate] + if index.ndim != 1: + raise ValueError( + f"Coordinates used for interpolation must be 1D, " + f"{use_coordinate} is {index.ndim}D." + ) + index = index.to_index() + + # TODO: index.name is None for multiindexes + # set name for nice error messages below + if isinstance(index, pd.MultiIndex): + index.name = dim + + if not index.is_monotonic: + raise ValueError(f"Index {index.name!r} must be monotonically increasing") + + if not index.is_unique: + raise ValueError(f"Index {index.name!r} has duplicate values") + + # Special case for non-standard calendar indexes + # Numerical datetime values are defined with respect to 1970-01-01T00:00:00 in units of nanoseconds + if isinstance(index, (CFTimeIndex, pd.DatetimeIndex)): + offset = type(index[0])(1970, 1, 1) + if isinstance(index, CFTimeIndex): + index = index.values + index = Variable( + data=datetime_to_numeric(index, offset=offset, datetime_unit="ns"), + dims=(dim,), + ) + + # raise if index cannot be cast to a float (e.g. MultiIndex) + try: + index = index.values.astype(np.float64) + except (TypeError, ValueError): + # pandas raises a TypeError + # xarray/numpy raise a ValueError + raise TypeError( + f"Index {index.name!r} must be castable to float64 to support " + f"interpolation, got {type(index).__name__}." 
+ ) return index @@ -263,11 +293,13 @@ def interp_na( use_coordinate: Union[bool, str] = True, method: str = "linear", limit: int = None, - max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64] = None, + max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None, **kwargs, ): """Interpolate values according to different methods. """ + from xarray.coding.cftimeindex import CFTimeIndex + if dim is None: raise NotImplementedError("dim is a required argument") @@ -281,26 +313,11 @@ def interp_na( if ( dim in self.indexes - and isinstance(self.indexes[dim], pd.DatetimeIndex) + and isinstance(self.indexes[dim], (pd.DatetimeIndex, CFTimeIndex)) and use_coordinate ): - if not isinstance(max_gap, (np.timedelta64, pd.Timedelta, str)): - raise TypeError( - f"Underlying index is DatetimeIndex. Expected max_gap of type str, pandas.Timedelta or numpy.timedelta64 but received {max_type}" - ) - - if isinstance(max_gap, str): - try: - max_gap = pd.to_timedelta(max_gap) - except ValueError: - raise ValueError( - f"Could not convert {max_gap!r} to timedelta64 using pandas.to_timedelta" - ) - - if isinstance(max_gap, pd.Timedelta): - max_gap = np.timedelta64(max_gap.value, "ns") - - max_gap = np.timedelta64(max_gap, "ns").astype(np.float64) + # Convert to float + max_gap = timedelta_to_numeric(max_gap) if not use_coordinate: if not isinstance(max_gap, (Number, np.number)): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index f70e96217e8..f9989c2c8c9 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -6,8 +6,10 @@ try: import dask.array as dask_array + from . import dask_array_compat except ImportError: dask_array = None + dask_array_compat = None # type: ignore def _replace_nan(a, val): @@ -141,7 +143,15 @@ def nanmean(a, axis=None, dtype=None, out=None): def nanmedian(a, axis=None, out=None): - return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis) + # The dask algorithm works by rechunking to one chunk along axis + # Make sure we trigger the dask error when passing all dimensions + # so that we don't rechunk the entire array to one chunk and + # possibly blow memory + if axis is not None and len(np.atleast_1d(axis)) == a.ndim: + axis = None + return _dask_or_eager_func( + "nanmedian", dask_module=dask_array_compat, eager_module=nputils + )(a, axis=axis) def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 3fe2c254b0f..cf189e471cc 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +from numpy.core.multiarray import normalize_axis_index try: import bottleneck as bn @@ -13,15 +14,6 @@ _USE_BOTTLENECK = False -def _validate_axis(data, axis): - ndim = data.ndim - if not -ndim <= axis < ndim: - raise IndexError(f"axis {axis!r} out of bounds [-{ndim}, {ndim})") - if axis < 0: - axis += ndim - return axis - - def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] @@ -29,13 +21,13 @@ def _select_along_axis(values, idx, axis): def nanfirst(values, axis): - axis = _validate_axis(values, axis) + axis = normalize_axis_index(axis, values.ndim) idx_first = np.argmax(~pd.isnull(values), axis=axis) return _select_along_axis(values, idx_first, axis) def nanlast(values, axis): - axis = _validate_axis(values, axis) + axis = normalize_axis_index(axis, values.ndim) rev = (slice(None),) * axis + (slice(None, 
None, -1),)
     idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
     return _select_along_axis(values, idx_last, axis)
@@ -186,7 +178,7 @@ def _rolling_window(a, window, axis=-1):
     This function is taken from https://github.com/numpy/numpy/pull/31
     but slightly modified to accept axis option.
     """
-    axis = _validate_axis(a, axis)
+    axis = normalize_axis_index(axis, a.ndim)
     a = np.swapaxes(a, axis, -1)
 
     if window < 1:
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index fbb5ef94ca2..facfa06b23c 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -7,11 +7,13 @@
 except ImportError:
     pass
 
+import collections
 import itertools
 import operator
 from typing import (
     Any,
     Callable,
+    DefaultDict,
     Dict,
     Hashable,
     Mapping,
@@ -152,6 +154,48 @@ def map_blocks(
     --------
     dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks,
     xarray.DataArray.map_blocks
+
+    Examples
+    --------
+
+    Calculate an anomaly from climatology using ``.groupby()``. Using
+    ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
+    its indices, and its methods like ``.groupby()``.
+
+    >>> def calculate_anomaly(da, groupby_type='time.month'):
+    ...     # Necessary workaround to xarray's check with zero dimensions
+    ...     # https://github.com/pydata/xarray/issues/3575
+    ...     if sum(da.shape) == 0:
+    ...         return da
+    ...     gb = da.groupby(groupby_type)
+    ...     clim = gb.mean(dim='time')
+    ...     return gb - clim
+    >>> time = xr.cftime_range('1990-01', '1992-01', freq='M')
+    >>> np.random.seed(123)
+    >>> array = xr.DataArray(np.random.rand(len(time)),
+    ...                      dims="time", coords=[time]).chunk()
+    >>> xr.map_blocks(calculate_anomaly, array).compute()
+    <xarray.DataArray (time: 24)>
+    array([ 0.12894847,  0.11323072, -0.0855964 , -0.09334032,  0.26848862,
+            0.12382735,  0.22460641,  0.07650108, -0.07673453, -0.22865714,
+           -0.19063865,  0.0590131 , -0.12894847, -0.11323072,  0.0855964 ,
+            0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
+            0.07673453,  0.22865714,  0.19063865, -0.0590131 ])
+    Coordinates:
+      * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
+
+    Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
+    to the function being applied in ``xr.map_blocks()``:
+
+    >>> xr.map_blocks(calculate_anomaly, array, kwargs={'groupby_type': 'time.year'})
+    <xarray.DataArray (time: 24)>
+    array([ 0.15361741, -0.25671244, -0.31600032,  0.008463  ,  0.1766172 ,
+           -0.11974531,  0.43791243,  0.14197797, -0.06191987, -0.15073425,
+           -0.19967375,  0.18619794, -0.05100474, -0.42989909, -0.09153273,
+            0.24841842, -0.30708526, -0.31412523,  0.04197439,  0.0422506 ,
+            0.14482397,  0.35985481,  0.23487834,  0.12144652])
+    Coordinates:
+      * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
     """
 
     def _wrapper(func, obj, to_array, args, kwargs):
@@ -221,7 +265,12 @@ def _wrapper(func, obj, to_array, args, kwargs):
     indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
     indexes.update({k: template.indexes[k] for k in new_indexes})
 
+    # We're building a new HighLevelGraph hlg. We'll have one new layer
+    # for each variable in the dataset, which is the result of the
+    # func applied to the values.
+ graph: Dict[Any, Any] = {} + new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) ) @@ -310,9 +359,20 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - graph[key] = (operator.getitem, from_wrapper, name) + # We're adding multiple new layers to the graph: + # The first new layer is the result of the computation on + # the array. + # Then we add one layer per variable, which extracts the + # result for that variable, and depends on just the first new + # layer. + new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) + + hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) - graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + for gname_l, layer in new_layers.items(): + # This adds in the getitems for each variable in the dataset. + hlg.dependencies[gname_l] = {gname} + hlg.layers[gname_l] = layer result = Dataset(coords=indexes, attrs=template.attrs) for name, gname_l in var_key_map.items(): @@ -325,7 +385,7 @@ def _wrapper(func, obj, to_array, args, kwargs): var_chunks.append((len(indexes[dim]),)) data = dask.array.Array( - graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype + hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index fb388490d06..2b3b7da6217 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -184,6 +184,7 @@ def map(self, func, shortcut=False, args=(), **kwargs): Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how to stack together the array. The rule is: + 1. If the dimension along which the group coordinate is defined is still in the first grouped array after applying `func`, then stack over this dimension. @@ -196,11 +197,13 @@ def map(self, func, shortcut=False, args=(), **kwargs): Callable to apply to each array. shortcut : bool, optional Whether or not to shortcut evaluation under the assumptions that: + (1) The action of `func` does not depend on any of the array metadata (attributes or coordinates) but only on the data and dimensions. (2) The action of `func` creates arrays with homogeneous metadata, that is, with the same dimensions and attributes. + If these conditions are satisfied `shortcut` provides significant speedup. This should be the case for many common groupby operations (e.g., applying numpy ufuncs). @@ -275,6 +278,7 @@ def map(self, func, args=(), shortcut=None, **kwargs): Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how to stack together the datasets. The rule is: + 1. If the dimension along which the group coordinate is defined is still in the first grouped item after applying `func`, then stack over this dimension. 
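The new-layer bookkeeping in the ``xarray/core/parallel.py`` hunks above is the core of that change: one layer holds the per-block results for the whole dataset, and one extra ``getitem`` layer per variable extracts that variable's piece and depends only on the compute layer. A minimal sketch of the same layering with plain graph dicts; the layer names (``compute-0``, ``getitem-*``) and the toy ``compute_all`` function are invented for illustration, and it assumes a dask version contemporary with this PR:

    import operator

    import dask
    from dask.highlevelgraph import HighLevelGraph


    def compute_all():
        # Stand-in for the wrapped func: returns every "variable" at once.
        return {"a": 1, "b": 2}


    layers = {"compute-0": {("compute-0", 0): (compute_all,)}}
    dependencies = {"compute-0": set()}
    for var in ("a", "b"):
        name = f"getitem-{var}-0"
        # One extra layer per variable; each depends only on the compute layer.
        layers[name] = {(name, 0): (operator.getitem, ("compute-0", 0), var)}
        dependencies[name] = {"compute-0"}

    hlg = HighLevelGraph(layers, dependencies)
    a, b = dask.get(hlg, [("getitem-a-0", 0), ("getitem-b-0", 0)])
    print(a, b)  # 1 2

Keying each extraction layer by variable name mirrors how the PR points each resulting ``dask.array.Array`` (constructed with ``name=gname_l``) at just its own layer of the graph.
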
diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 6681375c18e..e335365d5ca 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -547,7 +547,12 @@ def __eq__(self, other) -> bool: return False def __hash__(self) -> int: - return hash((ReprObject, self._value)) + return hash((type(self), self._value)) + + def __dask_tokenize__(self): + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) @contextlib.contextmanager diff --git a/xarray/core/variable.py b/xarray/core/variable.py index aa04cffb5ea..daa8678157b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -346,7 +346,10 @@ def data(self): def data(self, data): data = as_compatible_data(data) if data.shape != self.shape: - raise ValueError("replacement data must match the Variable's shape") + raise ValueError( + f"replacement data must match the Variable's shape. " + f"replacement data has shape {data.shape}; Variable has shape {self.shape}" + ) self._data = data def load(self, **kwargs): @@ -739,7 +742,10 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): data = as_indexable(self._data)[actual_indexer] mask = indexing.create_mask(indexer, self.shape, data) - data = duck_array_ops.where(mask, fill_value, data) + # we need to invert the mask in order to pass data first. This helps + # pint to choose the correct unit + # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed + data = duck_array_ops.where(np.logical_not(mask), data, fill_value) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. @@ -1051,7 +1057,9 @@ def isel( invalid = indexers.keys() - set(self.dims) if invalid: - raise ValueError("dimensions %r do not exist" % invalid) + raise ValueError( + f"dimensions {invalid} do not exist. Expected one or more of {self.dims}" + ) key = tuple(indexers.get(dim, slice(None)) for dim in self.dims) return self[key] @@ -1096,24 +1104,16 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): else: dtype = self.dtype - shape = list(self.shape) - shape[axis] = min(abs(count), shape[axis]) - - if isinstance(trimmed_data, dask_array_type): - chunks = list(trimmed_data.chunks) - chunks[axis] = (shape[axis],) - full = functools.partial(da.full, chunks=chunks) - else: - full = np.full - - filler = full(shape, fill_value, dtype=dtype) - - if count > 0: - arrays = [filler, trimmed_data] - else: - arrays = [trimmed_data, filler] + width = min(abs(count), self.shape[axis]) + dim_pad = (width, 0) if count >= 0 else (0, width) + pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = duck_array_ops.concatenate(arrays, axis) + data = duck_array_ops.pad( + trimmed_data.astype(dtype), + pads, + mode="constant", + constant_values=fill_value, + ) if isinstance(data, dask_array_type): # chunked data should come out with the same chunks; this makes @@ -1137,7 +1137,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Value to use for newly missing values **shifts_kwargs: The keyword arguments form of ``shifts``. - One of shifts or shifts_kwarg must be provided. + One of shifts or shifts_kwargs must be provided. Returns ------- @@ -1245,7 +1245,7 @@ def roll(self, shifts=None, **shifts_kwargs): left. **shifts_kwargs: The keyword arguments form of ``shifts``. - One of shifts or shifts_kwarg must be provided. + One of shifts or shifts_kwargs must be provided. 
Returns ------- @@ -1622,8 +1622,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): if not shortcut: for var in variables: if var.dims != first_var.dims: - raise ValueError("inconsistent dimensions") - utils.remove_incompatible_items(attrs, var.attrs) + raise ValueError( + f"Variable has dimensions {list(var.dims)} but first Variable has dimensions {list(first_var.dims)}" + ) return cls(dims, data, attrs, encoding) @@ -1693,6 +1694,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): This optional parameter specifies the interpolation method to use when the desired quantile lies between two data points ``i < j``: + * linear: ``i + (j - i) * fraction``, where ``fraction`` is the fractional part of the index surrounded by ``i`` and ``j``. @@ -1700,6 +1702,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): * higher: ``j``. * nearest: ``i`` or ``j``, whichever is nearest. * midpoint: ``(i + j) / 2``. + keep_attrs : bool, optional If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -1716,7 +1719,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): See Also -------- - numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile """ @@ -1736,7 +1739,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - return np.moveaxis(np.nanpercentile(npa, **kwargs), 0, -1) + return np.moveaxis(np.nanquantile(npa, **kwargs), 0, -1) axis = np.arange(-1, -1 * len(dim) - 1, -1) result = apply_ufunc( @@ -1748,7 +1751,7 @@ def _wrapper(npa, **kwargs): output_dtypes=[np.float64], output_sizes={"quantile": len(q)}, dask="parallelized", - kwargs={"q": q * 100, "axis": axis, "interpolation": interpolation}, + kwargs={"q": q, "axis": axis, "interpolation": interpolation}, ) # for backward compatibility diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 7f13ba601fe..4f3268c1203 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -61,6 +61,10 @@ class FacetGrid: axes : numpy object array Contains axes in corresponding position, as returned from plt.subplots + col_labels : list + list of :class:`matplotlib.text.Text` instances corresponding to column titles. + row_labels : list + list of :class:`matplotlib.text.Text` instances corresponding to row titles. 
fig : matplotlib.Figure The figure containing all the axes name_dicts : numpy object array @@ -200,6 +204,8 @@ def __init__( self._ncol = ncol self._col_var = col self._col_wrap = col_wrap + self.row_labels = [None] * nrow + self.col_labels = [None] * ncol self._x_var = None self._y_var = None self._cmap_extend = None @@ -482,22 +488,32 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar ax.set_title(title, size=size, **kwargs) else: # The row titles on the right edge of the grid - for ax, row_name in zip(self.axes[:, -1], self.row_names): + for index, (ax, row_name, handle) in enumerate( + zip(self.axes[:, -1], self.row_names, self.row_labels) + ): title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar) - ax.annotate( - title, - xy=(1.02, 0.5), - xycoords="axes fraction", - rotation=270, - ha="left", - va="center", - **kwargs, - ) + if not handle: + self.row_labels[index] = ax.annotate( + title, + xy=(1.02, 0.5), + xycoords="axes fraction", + rotation=270, + ha="left", + va="center", + **kwargs, + ) + else: + handle.set_text(title) # The column titles on the top row - for ax, col_name in zip(self.axes[0, :], self.col_names): + for index, (ax, col_name, handle) in enumerate( + zip(self.axes[0, :], self.col_names, self.col_labels) + ): title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar) - ax.set_title(title, size=size, **kwargs) + if not handle: + self.col_labels[index] = ax.set_title(title, size=size, **kwargs) + else: + handle.set_text(title) return self diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index d38c9765352..98131887e28 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -17,13 +17,11 @@ _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, - _interval_to_double_bound_points, - _interval_to_mid_points, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, + _resolve_intervals_1dplot, _resolve_intervals_2dplot, _update_axes, - _valid_other_type, get_axis, import_matplotlib_pyplot, label_from_attrs, @@ -296,29 +294,10 @@ def line( ax = get_axis(figsize, size, aspect, ax) xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue) - # Remove pd.Intervals if contained in xplt.values. - if _valid_other_type(xplt.values, [pd.Interval]): - # Is it a step plot? (see matplotlib.Axes.step) - if kwargs.get("linestyle", "").startswith("steps-"): - xplt_val, yplt_val = _interval_to_double_bound_points( - xplt.values, yplt.values - ) - # Remove steps-* to be sure that matplotlib is not confused - kwargs["linestyle"] = ( - kwargs["linestyle"] - .replace("steps-pre", "") - .replace("steps-post", "") - .replace("steps-mid", "") - ) - if kwargs["linestyle"] == "": - del kwargs["linestyle"] - else: - xplt_val = _interval_to_mid_points(xplt.values) - yplt_val = yplt.values - xlabel += "_center" - else: - xplt_val = xplt.values - yplt_val = yplt.values + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, xlabel, ylabel, kwargs + ) _ensure_plottable(xplt_val, yplt_val) @@ -360,6 +339,7 @@ def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs): ---------- where : {'pre', 'post', 'mid'}, optional, default 'pre' Define where the steps should be placed: + - 'pre': The y value is continued constantly to the left from every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the value ``y[i]``. 
@@ -367,12 +347,13 @@ def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs): every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the value ``y[i]``. - 'mid': Steps occur half-way between the *x* positions. - Note that this parameter is ignored if the x coordinate consists of + + Note that this parameter is ignored if one coordinate consists of :py:func:`pandas.Interval` values, e.g. as a result of :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual boundaries of the interval are used. - *args, **kwargs : optional + ``*args``, ``**kwargs`` : optional Additional arguments following :py:func:`xarray.plot.line` """ if where not in {"pre", "post", "mid"}: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index b5bd96ab127..cb3bef6d409 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -21,26 +21,6 @@ ROBUST_PERCENTILE = 2.0 -def import_seaborn(): - """import seaborn and handle deprecation of apionly module""" - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - try: - import seaborn.apionly as sns - - if ( - w - and issubclass(w[-1].category, UserWarning) - and ("seaborn.apionly module" in str(w[-1].message)) - ): - raise ImportError - except ImportError: - import seaborn as sns - finally: - warnings.resetwarnings() - return sns - - _registered = False @@ -143,7 +123,7 @@ def _color_palette(cmap, n_colors): except ValueError: # ValueError happens when mpl doesn't like a colormap, try seaborn try: - from seaborn.apionly import color_palette + from seaborn import color_palette pal = color_palette(cmap, n_colors=n_colors) except (ValueError, ImportError): @@ -477,6 +457,42 @@ def _interval_to_double_bound_points(xarray, yarray): return xarray, yarray +def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): + """ + Helper function to replace the values of x and/or y coordinate arrays + containing pd.Interval with their mid-points or - for step plots - double + points which double the length. + """ + + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get("linestyle", "").startswith("steps-"): + + # Convert intervals to double points + if _valid_other_type(np.array([xval, yval]), [pd.Interval]): + raise TypeError("Can't step plot intervals against intervals.") + if _valid_other_type(xval, [pd.Interval]): + xval, yval = _interval_to_double_bound_points(xval, yval) + if _valid_other_type(yval, [pd.Interval]): + yval, xval = _interval_to_double_bound_points(yval, xval) + + # Remove steps-* to be sure that matplotlib is not confused + del kwargs["linestyle"] + + # Is it another kind of plot? + else: + + # Convert intervals to mid points and adjust labels + if _valid_other_type(xval, [pd.Interval]): + xval = _interval_to_mid_points(xval) + xlabel += "_center" + if _valid_other_type(yval, [pd.Interval]): + yval = _interval_to_mid_points(yval) + ylabel += "_center" + + # return converted arguments + return xval, yval, xlabel, ylabel, kwargs + + def _resolve_intervals_2dplot(val, func_name): """ Helper function to replace the values of a coordinate array containing @@ -571,10 +587,6 @@ def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None - # TODO: remove when min numpy version is bumped to 1.13 - # There's a cyclic dependency via DataArray, so we can't import from - # xarray.ufuncs in global scope. 
- from xarray.ufuncs import maximum, minimum # Calculate vmin and vmax automatically for `robust=True` if robust: @@ -603,9 +615,7 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. darray = ((darray.astype("f8") - vmin) / (vmax - vmin)).astype("f4") - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "xarray.ufuncs", PendingDeprecationWarning) - return minimum(maximum(darray, 0), 1) + return np.minimum(np.maximum(darray, 0), 1) def _update_axes( diff --git a/xarray/testing.py b/xarray/testing.py index 5c3ca8a3cca..ac189f7e023 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -10,6 +10,13 @@ from xarray.core.indexes import default_indexes from xarray.core.variable import IndexVariable, Variable +__all__ = ( + "assert_allclose", + "assert_chunks_equal", + "assert_equal", + "assert_identical", +) + def _decode_string_data(data): if data.dtype.kind == "S": diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 6592360cdf2..df86b5715e9 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -16,7 +16,6 @@ from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options -from xarray.plot.utils import import_seaborn # import mpl and change the backend before other mpl imports try: @@ -71,6 +70,7 @@ def LooseVersion(vstring): has_iris, requires_iris = _importorskip("iris") has_cfgrib, requires_cfgrib = _importorskip("cfgrib") has_numbagg, requires_numbagg = _importorskip("numbagg") +has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") # some special cases @@ -78,12 +78,6 @@ def LooseVersion(vstring): requires_scipy_or_netCDF4 = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) -try: - import_seaborn() - has_seaborn = True -except ImportError: - has_seaborn = False -requires_seaborn = pytest.mark.skipif(not has_seaborn, reason="requires seaborn") # change some global options for tests set_options(warn_for_unclosed_files=True) diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 5fe5b8c3f59..f178720a6e1 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -11,6 +11,7 @@ requires_cftime, requires_dask, ) +from .test_dask import assert_chunks_equal, raise_if_dask_computes class TestDatetimeAccessor: @@ -37,24 +38,38 @@ def setup(self): name="data", ) - def test_field_access(self): - years = xr.DataArray( - self.times.year, name="year", coords=[self.times], dims=["time"] - ) - months = xr.DataArray( - self.times.month, name="month", coords=[self.times], dims=["time"] - ) - days = xr.DataArray( - self.times.day, name="day", coords=[self.times], dims=["time"] - ) - hours = xr.DataArray( - self.times.hour, name="hour", coords=[self.times], dims=["time"] + @pytest.mark.parametrize( + "field", + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "week", + "weekofyear", + "dayofweek", + "weekday", + "dayofyear", + "quarter", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", + "is_leap_year", + ], + ) + def test_field_access(self, field): + expected = xr.DataArray( + getattr(self.times, field), name=field, coords=[self.times], dims=["time"] ) - - 
assert_equal(years, self.data.time.dt.year) - assert_equal(months, self.data.time.dt.month) - assert_equal(days, self.data.time.dt.day) - assert_equal(hours, self.data.time.dt.hour) + actual = getattr(self.data.time.dt, field) + assert_equal(expected, actual) def test_strftime(self): assert ( @@ -69,55 +84,74 @@ def test_not_datetime_type(self): nontime_data.time.dt @requires_dask - def test_dask_field_access(self): + @pytest.mark.parametrize( + "field", + [ + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "week", + "weekofyear", + "dayofweek", + "weekday", + "dayofyear", + "quarter", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", + "is_leap_year", + ], + ) + def test_dask_field_access(self, field): import dask.array as da - years = self.times_data.dt.year - months = self.times_data.dt.month - hours = self.times_data.dt.hour - days = self.times_data.dt.day - floor = self.times_data.dt.floor("D") - ceil = self.times_data.dt.ceil("D") - round = self.times_data.dt.round("D") - strftime = self.times_data.dt.strftime("%Y-%m-%d %H:%M:%S") + expected = getattr(self.times_data.dt, field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, field) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + + @requires_dask + @pytest.mark.parametrize( + "method, parameters", + [ + ("floor", "D"), + ("ceil", "D"), + ("round", "D"), + ("strftime", "%Y-%m-%d %H:%M:%S"), + ], + ) + def test_dask_accessor_method(self, method, parameters): + import dask.array as da + expected = getattr(self.times_data.dt, method)(parameters) dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) dask_times_2d = xr.DataArray( dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" ) - dask_year = dask_times_2d.dt.year - dask_month = dask_times_2d.dt.month - dask_day = dask_times_2d.dt.day - dask_hour = dask_times_2d.dt.hour - dask_floor = dask_times_2d.dt.floor("D") - dask_ceil = dask_times_2d.dt.ceil("D") - dask_round = dask_times_2d.dt.round("D") - dask_strftime = dask_times_2d.dt.strftime("%Y-%m-%d %H:%M:%S") - - # Test that the data isn't eagerly evaluated - assert isinstance(dask_year.data, da.Array) - assert isinstance(dask_month.data, da.Array) - assert isinstance(dask_day.data, da.Array) - assert isinstance(dask_hour.data, da.Array) - assert isinstance(dask_strftime.data, da.Array) - - # Double check that outcome chunksize is unchanged - dask_chunks = dask_times_2d.chunks - assert dask_year.data.chunks == dask_chunks - assert dask_month.data.chunks == dask_chunks - assert dask_day.data.chunks == dask_chunks - assert dask_hour.data.chunks == dask_chunks - assert dask_strftime.data.chunks == dask_chunks - - # Check the actual output from the accessors - assert_equal(years, dask_year.compute()) - assert_equal(months, dask_month.compute()) - assert_equal(days, dask_day.compute()) - assert_equal(hours, dask_hour.compute()) - assert_equal(floor, dask_floor.compute()) - assert_equal(ceil, dask_ceil.compute()) - assert_equal(round, dask_round.compute()) - assert_equal(strftime, dask_strftime.compute()) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, method)(parameters) + + 
assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) def test_seasons(self): dates = pd.date_range(start="2000/01/01", freq="M", periods=12) @@ -140,12 +174,108 @@ def test_seasons(self): assert_array_equal(seasons.values, dates.dt.season.values) - def test_rounders(self): + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_accessor_method(self, method, parameters): dates = pd.date_range("2014-01-01", "2014-05-01", freq="H") - xdates = xr.DataArray(np.arange(len(dates)), dims=["time"], coords=[dates]) - assert_array_equal(dates.floor("D").values, xdates.time.dt.floor("D").values) - assert_array_equal(dates.ceil("D").values, xdates.time.dt.ceil("D").values) - assert_array_equal(dates.round("D").values, xdates.time.dt.round("D").values) + xdates = xr.DataArray(dates, dims=["time"]) + expected = getattr(dates, method)(parameters) + actual = getattr(xdates.dt, method)(parameters) + assert_array_equal(expected, actual) + + +class TestTimedeltaAccessor: + @pytest.fixture(autouse=True) + def setup(self): + nt = 100 + data = np.random.rand(10, 10, nt) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + self.times = pd.timedelta_range(start="1 day", freq="6H", periods=nt) + + self.data = xr.DataArray( + data, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + self.times_arr = np.random.choice(self.times, size=(10, 10, nt)) + self.times_data = xr.DataArray( + self.times_arr, + coords=[lons, lats, self.times], + dims=["lon", "lat", "time"], + name="data", + ) + + def test_not_datetime_type(self): + nontime_data = self.data.copy() + int_data = np.arange(len(self.data.time)).astype("int8") + nontime_data["time"].values = int_data + with raises_regex(TypeError, "dt"): + nontime_data.time.dt + + @pytest.mark.parametrize( + "field", ["days", "seconds", "microseconds", "nanoseconds"] + ) + def test_field_access(self, field): + expected = xr.DataArray( + getattr(self.times, field), name=field, coords=[self.times], dims=["time"] + ) + actual = getattr(self.data.time.dt, field) + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_accessor_methods(self, method, parameters): + dates = pd.timedelta_range(start="1 day", end="30 days", freq="6H") + xdates = xr.DataArray(dates, dims=["time"]) + expected = getattr(dates, method)(parameters) + actual = getattr(xdates.dt, method)(parameters) + assert_array_equal(expected, actual) + + @requires_dask + @pytest.mark.parametrize( + "field", ["days", "seconds", "microseconds", "nanoseconds"] + ) + def test_dask_field_access(self, field): + import dask.array as da + + expected = getattr(self.times_data.dt, field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, field) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual, expected) + + @requires_dask + @pytest.mark.parametrize( + "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] + ) + def test_dask_accessor_method(self, method, parameters): + import dask.array as da + + expected = getattr(self.times_data.dt, method)(parameters) + dask_times_arr = 
da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = getattr(dask_times_2d.dt, method)(parameters) + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) _CFTIME_CALENDARS = [ diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1e135ebd3e1..b7ba70ef6c4 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -33,10 +33,11 @@ from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.coding.variables import SerializationWarning +from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing from xarray.core.options import set_options from xarray.core.pycompat import dask_array_type -from xarray.tests import mock +from xarray.tests import LooseVersion, mock from . import ( arm_xfail, @@ -75,9 +76,14 @@ pass try: + import dask import dask.array as da + + dask_version = dask.__version__ except ImportError: - pass + # needed for xfailed tests when dask < 2.4.0 + # remove when min dask > 2.4.0 + dask_version = "10.0" ON_WINDOWS = sys.platform == "win32" @@ -522,15 +528,35 @@ def test_roundtrip_coordinates(self): with self.roundtrip(original) as actual: assert_identical(original, actual) + original["foo"].encoding["coordinates"] = "y" + with self.roundtrip(original, open_kwargs={"decode_coords": False}) as expected: + # check roundtripping when decode_coords=False + with self.roundtrip( + expected, open_kwargs={"decode_coords": False} + ) as actual: + assert_identical(expected, actual) + def test_roundtrip_global_coordinates(self): - original = Dataset({"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])}) + original = Dataset( + {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])} + ) with self.roundtrip(original) as actual: assert_identical(original, actual) + # test that global "coordinates" is as expected + _, attrs = encode_dataset_coordinates(original) + assert attrs["coordinates"] == "y" + + # test warning when global "coordinates" is already set + original.attrs["coordinates"] = "foo" + with pytest.warns(SerializationWarning): + _, attrs = encode_dataset_coordinates(original) + assert attrs["coordinates"] == "foo" + def test_roundtrip_coordinates_with_space(self): original = Dataset(coords={"x": 0, "y z": 1}) expected = Dataset({"y z": 1}, {"x": 0}) - with pytest.warns(xr.SerializationWarning): + with pytest.warns(SerializationWarning): with self.roundtrip(original) as actual: assert_identical(expected, actual) @@ -810,6 +836,18 @@ def equals_latlon(obj): assert "coordinates" not in ds["lat"].attrs assert "coordinates" not in ds["lon"].attrs + original["temp"].encoding["coordinates"] = "lat" + with self.roundtrip(original) as actual: + assert_identical(actual, original) + original["precip"].encoding["coordinates"] = "lat" + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=True) as ds: + assert "lon" not in ds["temp"].encoding["coordinates"] + assert "lon" not in ds["precip"].encoding["coordinates"] + assert "coordinates" not in ds["lat"].encoding + assert "coordinates" not in ds["lon"].encoding + def test_roundtrip_endian(self): ds = Dataset( { @@ -1690,6 +1728,7 @@ def test_hidden_zarr_keys(self): with 
xr.decode_cf(store): pass + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_write_persistence_modes(self): original = create_test_data() @@ -1754,6 +1793,7 @@ def test_encoding_kwarg_fixed_width_string(self): def test_dataset_caching(self): super().test_dataset_caching() + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_write(self): ds, ds_to_append, _ = create_append_test_data() with self.create_zarr_target() as store_target: @@ -1830,6 +1870,7 @@ def test_check_encoding_is_consistent_after_append(self): xr.concat([ds, ds_to_append], dim="time"), ) + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_with_new_variable(self): ds, ds_to_append, ds_with_new_var = create_append_test_data() @@ -2149,7 +2190,7 @@ class TestH5NetCDFData(NetCDF4Base): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: - yield backends.H5NetCDFStore(tmp_file, "w") + yield backends.H5NetCDFStore.open(tmp_file, "w") @pytest.mark.filterwarnings("ignore:complex dtypes are supported by h5py") @pytest.mark.parametrize( @@ -2312,6 +2353,27 @@ def test_dump_encodings_h5py(self): assert actual.x.encoding["compression"] == "lzf" assert actual.x.encoding["compression_opts"] is None + def test_already_open_dataset_group(self): + import h5netcdf + + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + group = nc.createGroup("g") + v = group.createVariable("x", "int") + v[...] = 42 + + h5 = h5netcdf.File(tmp_file, mode="r") + store = backends.H5NetCDFStore(h5["g"]) + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + h5 = h5netcdf.File(tmp_file, mode="r") + store = backends.H5NetCDFStore(h5, group="g") + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + @requires_h5netcdf class TestH5NetCDFFileObject(TestH5NetCDFData): @@ -2478,6 +2540,7 @@ def test_open_mfdataset_manyfiles( @requires_netCDF4 +@requires_dask def test_open_mfdataset_list_attr(): """ Case when an attribute of type list differs across the multiple files @@ -2799,6 +2862,42 @@ def test_attrs_mfdataset(self): with raises_regex(AttributeError, "no attribute"): actual.test2 + def test_open_mfdataset_attrs_file(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + + def test_open_mfdataset_attrs_file_path(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + tmp1 = Path(tmp1) + tmp2 = Path(tmp2) + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not 
retained, e.g., + assert "test1" not in actual.attrs + def test_open_mfdataset_auto_combine(self): original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) with create_tmp_file() as tmp1: diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 3e0474e7b60..0f191049284 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,10 +1,12 @@ from contextlib import suppress import numpy as np +import pandas as pd import pytest import xarray as xr from xarray.coding import variables +from xarray.conventions import decode_cf_variable, encode_cf_variable from . import assert_equal, assert_identical, requires_dask @@ -20,20 +22,36 @@ def test_CFMaskCoder_decode(): assert_identical(expected, encoded) -def test_CFMaskCoder_encode_missing_fill_values_conflict(): - original = xr.Variable( - ("x",), - [0.0, -1.0, 1.0], - encoding={"_FillValue": np.float32(1e20), "missing_value": np.float64(1e20)}, - ) - coder = variables.CFMaskCoder() - encoded = coder.encode(original) +encoding_with_dtype = { + "dtype": np.dtype("float64"), + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +encoding_without_dtype = { + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = { + "numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype), + "numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype), + "times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype), +} + + +@pytest.mark.parametrize( + ("data", "encoding"), + CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(), + ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()), +) +def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding): + original = xr.Variable(("x",), data, encoding=encoding) + encoded = encode_cf_variable(original) assert encoded.dtype == encoded.attrs["missing_value"].dtype assert encoded.dtype == encoded.attrs["_FillValue"].dtype with pytest.warns(variables.SerializationWarning): - roundtripped = coder.decode(coder.encode(original)) + roundtripped = decode_cf_variable("foo", encoded) assert_identical(roundtripped, original) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d012fb36c35..00c34940ce4 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -451,7 +451,7 @@ def test_cf_datetime_nan(num_dates, units, expected_list): warnings.filterwarnings("ignore", "All-NaN") actual = coding.times.decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions - expected = pd.to_datetime(expected_list) + expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") assert_array_equal(expected, actual) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index cd26e7fb60b..eb2c6e1dbf7 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -22,7 +22,7 @@ _new_tile_id, ) -from . import assert_equal, assert_identical, raises_regex +from . 
import assert_equal, assert_identical, raises_regex, requires_cftime from .test_dataset import create_test_data @@ -365,9 +365,10 @@ def test_nested_concat(self): expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1])}) assert_identical(expected, actual) - objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] - with pytest.raises(KeyError): - combine_nested(objs, concat_dim="x") + objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1]})] + actual = combine_nested(objs, concat_dim="x") + expected = Dataset({"x": [0, 1], "y": [0]}) + assert_identical(expected, actual) @pytest.mark.parametrize( "join, expected", @@ -711,6 +712,22 @@ def test_check_for_impossible_ordering(self): ): combine_by_coords([ds1, ds0]) + def test_combine_by_coords_incomplete_hypercube(self): + # test that this succeeds with default fill_value + x1 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [0], "x": [0]}) + x2 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [1], "x": [0]}) + x3 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [0], "x": [1]}) + actual = combine_by_coords([x1, x2, x3]) + expected = Dataset( + {"a": (("y", "x"), [[1, 1], [1, np.nan]])}, + coords={"y": [0, 1], "x": [0, 1]}, + ) + assert_identical(expected, actual) + + # test that this fails if fill_value is None + with pytest.raises(ValueError): + combine_by_coords([x1, x2, x3], fill_value=None) + @pytest.mark.filterwarnings( "ignore:In xarray version 0.15 `auto_combine` " "will be deprecated" @@ -877,3 +894,25 @@ def test_auto_combine_without_coords(self): objs = [Dataset({"foo": ("x", [0])}), Dataset({"foo": ("x", [1])})] with pytest.warns(FutureWarning, match="supplied do not have global"): auto_combine(objs) + + +@requires_cftime +def test_combine_by_coords_distant_cftime_dates(): + # Regression test for https://github.com/pydata/xarray/issues/3535 + import cftime + + time_1 = [cftime.DatetimeGregorian(4500, 12, 31)] + time_2 = [cftime.DatetimeGregorian(4600, 12, 31)] + time_3 = [cftime.DatetimeGregorian(5100, 12, 31)] + + da_1 = DataArray([0], dims=["time"], coords=[time_1], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() + da_3 = DataArray([2], dims=["time"], coords=[time_3], name="a").to_dataset() + + result = combine_by_coords([da_1, da_2, da_3]) + + expected_time = np.concatenate([time_1, time_2, time_3]) + expected = DataArray( + [0, 1, 2], dims=["time"], coords=[expected_time], name="a" + ).to_dataset() + assert_identical(result, expected) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 1f2634cc9b0..369903552ad 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -817,6 +817,24 @@ def test_vectorize_dask(): assert_identical(expected, actual) +@requires_dask +def test_vectorize_dask_new_output_dims(): + # regression test for GH3574 + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + func = lambda x: x[np.newaxis, ...] 
+ expected = data_array.expand_dims("z") + actual = apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + output_sizes={"z": 1}, + ).transpose(*expected.dims) + assert_identical(expected, actual) + + def test_output_wrong_number(): variable = xr.Variable("x", np.arange(10)) @@ -1043,6 +1061,60 @@ def test_dot(use_dask): pickle.loads(pickle.dumps(xr.dot(da_a))) +@pytest.mark.parametrize("use_dask", [True, False]) +def test_dot_align_coords(use_dask): + # GH 3694 + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + + # use partially overlapping coords + coords_a = {"a": np.arange(30), "b": np.arange(4)} + coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)} + + da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b) + + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + + # join="inner" is the default + actual = xr.dot(da_a, da_b) + # `dot` sums over the common dimensions of the arguments + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + actual = xr.dot(da_a, da_b, dims=...) + expected = (da_a * da_b).sum() + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + xr.dot(da_a, da_b) + + # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all + # join method (except "exact") + with xr.set_options(arithmetic_join="left"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="right"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="outer"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + def test_where(): cond = xr.DataArray([True, False], dims="x") actual = xr.where(cond, 1, 0) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0661ebb7a38..bd99181a947 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -462,3 +462,37 @@ def test_concat_join_kwarg(self): for join in expected: actual = concat([ds1, ds2], join=join, dim="x") assert_equal(actual, expected[join].to_array()) + + +@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) +@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) +def test_concat_attrs_first_variable(attr1, attr2): + + arrs = [ + DataArray([[1], [2]], dims=["x", "y"], attrs=attr1), + DataArray([[3], [4]], dims=["x", "y"], attrs=attr2), + ] + + concat_attrs = concat(arrs, "y").attrs + assert concat_attrs == attr1 + + +def test_concat_merge_single_non_dim_coord(): + da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) + da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) + + expected = DataArray(range(1, 7), dims="x", coords={"x": range(1, 7), "y": 1}) + + for coords in ["different", "minimal"]: + actual = concat([da1, da2], "x", coords=coords) + assert_identical(actual, expected) + + with raises_regex(ValueError, "'y' is not present in all datasets."): + concat([da1, da2], dim="x", coords="all") + + da1 = 
DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) + da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) + da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) + for coords in ["different", "all"]: + with raises_regex(ValueError, "'y' not present in all datasets"): + concat([da1, da2, da3], dim="x") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 09002e252b4..acb2400ea04 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -136,6 +136,20 @@ def test_multidimensional_coordinates(self): # Should not have any global coordinates. assert "coordinates" not in attrs + def test_do_not_overwrite_user_coordinates(self): + orig = Dataset( + coords={"x": [0, 1, 2], "y": ("x", [5, 6, 7]), "z": ("x", [8, 9, 10])}, + data_vars={"a": ("x", [1, 2, 3]), "b": ("x", [3, 5, 6])}, + ) + orig["a"].encoding["coordinates"] = "y" + orig["b"].encoding["coordinates"] = "z" + enc, _ = conventions.encode_dataset_coordinates(orig) + assert enc["a"].attrs["coordinates"] == "y" + assert enc["b"].attrs["coordinates"] == "z" + orig["a"].attrs["coordinates"] = "foo" + with raises_regex(ValueError, "'coordinates' found in both attrs"): + conventions.encode_dataset_coordinates(orig) + @requires_dask def test_string_object_warning(self): original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f3b10e3370c..8fb54c4ee84 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -216,8 +216,10 @@ def test_reduce(self): self.assertLazyAndAllClose(u.argmin(dim="x"), actual) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "dask"): + with raises_regex(NotImplementedError, "only works along an axis"): v.median() + with raises_regex(NotImplementedError, "only works along an axis"): + v.median(v.dims) with raise_if_dask_computes(): v.reduce(duck_array_ops.mean) @@ -1081,7 +1083,7 @@ def func(obj): actual = xr.map_blocks(func, obj) expected = func(obj) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1090,7 +1092,7 @@ def test_map_blocks_convert_args_to_list(obj): with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1105,7 +1107,7 @@ def add_attrs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(add_attrs, obj) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) def test_map_blocks_change_name(map_da): @@ -1118,7 +1120,7 @@ def change_name(obj): with raise_if_dask_computes(): actual = xr.map_blocks(change_name, map_da) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1127,7 +1129,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + 
assert_identical(actual, expected) def test_map_blocks_to_array(map_ds): @@ -1135,7 +1137,7 @@ def test_map_blocks_to_array(map_ds): actual = xr.map_blocks(lambda x: x.to_array(), map_ds) # to_array does not preserve name, so cannot use assert_identical - assert_equal(actual.compute(), map_ds.to_array().compute()) + assert_equal(actual, map_ds.to_array()) @pytest.mark.parametrize( @@ -1154,7 +1156,7 @@ def test_map_blocks_da_transformations(func, map_da): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_da) - assert_identical(actual.compute(), func(map_da).compute()) + assert_identical(actual, func(map_da)) @pytest.mark.parametrize( @@ -1173,7 +1175,7 @@ def test_map_blocks_ds_transformations(func, map_ds): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_ds) - assert_identical(actual.compute(), func(map_ds).compute()) + assert_identical(actual, func(map_ds)) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1186,7 +1188,20 @@ def func(obj): expected = xr.map_blocks(func, obj) actual = obj.map_blocks(func) - assert_identical(expected.compute(), actual.compute()) + assert_identical(expected, actual) + + +def test_map_blocks_hlg_layers(): + # regression test for #3599 + ds = xr.Dataset( + { + "x": (("a",), dask.array.ones(10, chunks=(5,))), + "z": (("b",), dask.array.ones(10, chunks=(5,))), + } + ) + mapped = ds.map_blocks(lambda x: x) + + xr.testing.assert_equal(mapped, ds) def test_make_meta(map_ds): @@ -1375,3 +1390,58 @@ def test_lazy_array_equiv_merge(compat): xr.merge([da1, da3], compat=compat) with raise_if_dask_computes(max_computes=2): xr.merge([da1, da2 / 2], compat=compat) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") # transpose_coords +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +@pytest.mark.parametrize( + "transform", + [ + lambda a: a.assign_attrs(new_attr="anew"), + lambda a: a.assign_coords(cxy=a.cxy), + lambda a: a.copy(), + lambda a: a.isel(x=np.arange(a.sizes["x"])), + lambda a: a.isel(x=slice(None)), + lambda a: a.loc[dict(x=slice(None))], + lambda a: a.loc[dict(x=np.arange(a.sizes["x"]))], + lambda a: a.loc[dict(x=a.x)], + lambda a: a.sel(x=a.x), + lambda a: a.sel(x=a.x.values), + lambda a: a.transpose(...), + lambda a: a.squeeze(), # no dimensions to squeeze + lambda a: a.sortby("x"), # "x" is already sorted + lambda a: a.reindex(x=a.x), + lambda a: a.reindex_like(a), + lambda a: a.rename({"cxy": "cnew"}).rename({"cnew": "cxy"}), + lambda a: a.pipe(lambda x: x), + lambda a: xr.align(a, xr.zeros_like(a))[0], + # assign + # swap_dims + # set_index / reset_index + ], +) +def test_transforms_pass_lazy_array_equiv(obj, transform): + with raise_if_dask_computes(): + assert_equal(obj, transform(obj)) + + +def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): + with raise_if_dask_computes(): + assert_equal(map_ds.cxy.broadcast_like(map_ds.cxy), map_ds.cxy) + assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy) + assert_equal(map_ds.map(lambda x: x), map_ds) + assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds) + assert_equal(map_ds.update({"a": map_ds.a}), map_ds) + + # fails because of index error + # assert_equal( + # map_ds.rename_dims({"x": "xnew"}).rename_dims({"xnew": "x"}), map_ds + # ) + + assert_equal( + map_ds.rename_vars({"cxy": "cnew"}).rename_vars({"cnew": "cxy"}), map_ds + ) + + assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da) + assert_equal(map_da.astype(map_da.dtype), map_da) + assert_equal(map_da.transpose("y", "x", 
transpose_coords=False).cxy, map_da.cxy) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index f957316d8ac..b9b719e8af9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -80,7 +80,7 @@ def test_repr_multiindex(self): assert expected == repr(self.mda) @pytest.mark.skipif( - LooseVersion(np.__version__) < "1.15", + LooseVersion(np.__version__) < "1.16", reason="old versions of numpy have different printing behavior", ) def test_repr_multiindex_long(self): @@ -752,12 +752,19 @@ def test_chunk(self): blocked = unblocked.chunk() assert blocked.chunks == ((3,), (4,)) + first_dask_name = blocked.data.name blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name blocked = unblocked.chunk(chunks=(3, 3)) assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + # name doesn't change when rechunking by same amount + # this fails if ReprObject doesn't have __dask_tokenize__ defined + assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name assert blocked.load().chunks is None @@ -1530,6 +1537,11 @@ def test_swap_dims(self): actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) + array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") + expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") + actual = array.swap_dims({"x": "y"}) + assert_identical(expected, actual) + def test_expand_dims_error(self): array = DataArray( np.random.randn(3, 4), @@ -3961,6 +3973,43 @@ def test_dot(self): with pytest.raises(TypeError): da.dot(dm.values) + def test_dot_align_coords(self): + # GH 3694 + + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z_a = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"]) + + z_m = range(2, 6) + dm_vals = range(4) + dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da.dot(dm) + + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + + # nd dot 1d + actual = da.dot(dm) + expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0]) + expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"]) + assert_equal(expected, actual) + + # multiple shared dims + dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4)) + j = np.linspace(-3, 3, 20) + dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"]) + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + actual = da.dot(dm) + expected_vals = np.tensordot( + da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2]) + ) + expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"]) + assert_equal(expected, actual) + def test_matmul(self): # copied from above (could make a fixture) @@ -3974,6 +4023,24 @@ def test_matmul(self): expected = da.dot(da) assert_identical(result, expected) + def test_matmul_align_coords(self): + # GH 3694 + + x_a = np.arange(6) + x_b = np.arange(2, 8) + da_vals = np.arange(6) + da_a = DataArray(da_vals, coords=[x_a], dims=["x"]) + da_b = DataArray(da_vals, coords=[x_b], dims=["x"]) + + # only test arithmetic_join="inner" (=default) + result = da_a @ da_b + expected = da_a.dot(da_b) + assert_identical(result, expected) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da_a @ da_b + def 
     def test_binary_op_propagate_indexes(self):
         # regression test for GH2227
         self.dv["x"] = np.arange(self.dv.sizes["x"])
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 7db1911621b..4e51e229b29 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -936,19 +936,35 @@ def test_chunk(self):
         expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)}
         assert reblocked.chunks == expected_chunks

+        def get_dask_names(ds):
+            return {k: v.data.name for k, v in ds.items()}
+
+        orig_dask_names = get_dask_names(reblocked)
+
         reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5})
         # time is not a dim in any of the data_vars, so it
         # doesn't get chunked
         expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)}
         assert reblocked.chunks == expected_chunks

+        # make sure dask names change when rechunking by different amounts
+        # regression test for GH3350
+        new_dask_names = get_dask_names(reblocked)
+        for k, v in new_dask_names.items():
+            assert v != orig_dask_names[k]
+
         reblocked = data.chunk(expected_chunks)
         assert reblocked.chunks == expected_chunks

         # reblock on already blocked data
+        orig_dask_names = get_dask_names(reblocked)
         reblocked = reblocked.chunk(expected_chunks)
+        new_dask_names = get_dask_names(reblocked)
         assert reblocked.chunks == expected_chunks
         assert_identical(reblocked, data)
+        # rechunking with same chunk sizes should not change names
+        for k, v in new_dask_names.items():
+            assert v == orig_dask_names[k]

         with raises_regex(ValueError, "some chunks"):
             data.chunk({"foo": 10})
@@ -1392,6 +1408,56 @@ def test_sel_dataarray_mindex(self):
             )
         )

+    def test_sel_categorical(self):
+        ind = pd.Series(["foo", "bar"], dtype="category")
+        df = pd.DataFrame({"ind": ind, "values": [1, 2]})
+        ds = df.set_index("ind").to_xarray()
+        actual = ds.sel(ind="bar")
+        expected = ds.isel(ind=1)
+        assert_identical(expected, actual)
+
+    def test_sel_categorical_error(self):
+        ind = pd.Series(["foo", "bar"], dtype="category")
+        df = pd.DataFrame({"ind": ind, "values": [1, 2]})
+        ds = df.set_index("ind").to_xarray()
+        with pytest.raises(ValueError):
+            ds.sel(ind="bar", method="nearest")
+        with pytest.raises(ValueError):
+            ds.sel(ind="bar", tolerance="nearest")
+
+    def test_categorical_index(self):
+        cat = pd.CategoricalIndex(
+            ["foo", "bar", "foo"],
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
+        )
+        ds = xr.Dataset(
+            {"var": ("cat", np.arange(3))},
+            coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])},
+        )
+        # test slice
+        actual = ds.sel(cat="foo")
+        expected = ds.isel(cat=[0, 2])
+        assert_identical(expected, actual)
+        # make sure the conversion to the array works
+        actual = ds.sel(cat="foo")["cat"].values
+        assert (actual == np.array(["foo", "foo"])).all()
+
+        ds = ds.set_index(index=["cat", "c"])
+        actual = ds.unstack("index")
+        assert actual["var"].shape == (2, 2)
+
+    def test_categorical_reindex(self):
+        cat = pd.CategoricalIndex(
+            ["foo", "bar", "baz"],
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
+        )
+        ds = xr.Dataset(
+            {"var": ("cat", np.arange(3))},
+            coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 2])},
+        )
+        actual = ds.reindex(cat=["foo"])["cat"].values
+        assert (actual == np.array(["foo"])).all()
+
     def test_sel_drop(self):
         data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]})
         expected = Dataset({"foo": 1})
@@ -2151,6 +2217,10 @@ def test_drop_variables(self):
         actual = data.drop(["time", "not_found_here"], errors="ignore")
         assert_identical(expected, actual)

+        with pytest.warns(PendingDeprecationWarning):
+            actual = data.drop({"time", "not_found_here"}, errors="ignore")
+            assert_identical(expected, actual)
+
     def test_drop_index_labels(self):
         data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]})

@@ -2444,6 +2514,9 @@ def test_rename_dims(self):
         with pytest.raises(ValueError):
             original.rename_dims(dims_dict_bad)

+        with pytest.raises(ValueError):
+            original.rename_dims({"x": "z"})
+
     def test_rename_vars(self):
         original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42})
         expected = Dataset(
@@ -2461,6 +2534,14 @@ def test_rename_vars(self):
         with pytest.raises(ValueError):
             original.rename_vars(names_dict_bad)

+    def test_rename_multiindex(self):
+        mindex = pd.MultiIndex.from_tuples(
+            [([1, 2]), ([3, 4])], names=["level0", "level1"]
+        )
+        data = Dataset({}, {"x": mindex})
+        with raises_regex(ValueError, "conflicting MultiIndex"):
+            data.rename({"x": "level0"})
+
     @requires_cftime
     def test_rename_does_not_change_CFTimeIndex_type(self):
         # make sure CFTimeIndex is not converted to DatetimeIndex #3522
@@ -2525,6 +2606,12 @@ def test_swap_dims(self):
         with raises_regex(ValueError, "replacement dimension"):
             original.swap_dims({"x": "z"})

+        expected = Dataset(
+            {"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])}
+        )
+        actual = original.swap_dims({"x": "u"})
+        assert_identical(expected, actual)
+
     def test_expand_dims_error(self):
         original = Dataset(
             {
@@ -3828,6 +3915,21 @@ def test_to_and_from_dataframe(self):
         expected = pd.DataFrame([[]], index=idx)
         assert expected.equals(actual), (expected, actual)

+    def test_from_dataframe_categorical(self):
+        cat = pd.CategoricalDtype(
+            categories=["foo", "bar", "baz", "qux", "quux", "corge"]
+        )
+        i1 = pd.Series(["foo", "bar", "foo"], dtype=cat)
+        i2 = pd.Series(["bar", "bar", "baz"], dtype=cat)
+
+        df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]})
+        ds = df.set_index("i1").to_xarray()
+        assert len(ds["i1"]) == 3
+
+        ds = df.set_index(["i1", "i2"]).to_xarray()
+        assert len(ds["i1"]) == 2
+        assert len(ds["i2"]) == 2
+
     @requires_sparse
     def test_from_dataframe_sparse(self):
         import sparse
diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py
index aee7bbd6b11..f4f11473e48 100644
--- a/xarray/tests/test_duck_array_ops.py
+++ b/xarray/tests/test_duck_array_ops.py
@@ -1,3 +1,4 @@
+import datetime as dt
 import warnings
 from textwrap import dedent

@@ -16,8 +17,12 @@
     gradient,
     last,
     mean,
+    np_timedelta64_to_float,
+    pd_timedelta_to_float,
+    py_timedelta_to_float,
     rolling_window,
     stack,
+    timedelta_to_numeric,
     where,
 )
 from xarray.core.pycompat import dask_array_type
@@ -672,13 +677,15 @@ def test_datetime_to_numeric_datetime64():

 @requires_cftime
 def test_datetime_to_numeric_cftime():
-    times = cftime_range("2000", periods=5, freq="7D").values
-    result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h")
+    times = cftime_range("2000", periods=5, freq="7D", calendar="standard").values
+    result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=int)
     expected = 24 * np.arange(0, 35, 7)
     np.testing.assert_array_equal(result, expected)

     offset = times[1]
-    result = duck_array_ops.datetime_to_numeric(times, offset=offset, datetime_unit="h")
+    result = duck_array_ops.datetime_to_numeric(
+        times, offset=offset, datetime_unit="h", dtype=int
+    )
     expected = 24 * np.arange(-7, 28, 7)
     np.testing.assert_array_equal(result, expected)

@@ -686,3 +693,70 @@ def test_datetime_to_numeric_cftime():
     result = duck_array_ops.datetime_to_numeric(times, datetime_unit="h", dtype=dtype)
     expected = 24 * np.arange(0, 35, 7).astype(dtype)
     np.testing.assert_array_equal(result, expected)
+
+
+@requires_cftime
+def test_datetime_to_numeric_potential_overflow():
+    import cftime
+
+    times = pd.date_range("2000", periods=5, freq="7D").values.astype("datetime64[us]")
+    cftimes = cftime_range(
+        "2000", periods=5, freq="7D", calendar="proleptic_gregorian"
+    ).values
+
+    offset = np.datetime64("0001-01-01")
+    cfoffset = cftime.DatetimeProlepticGregorian(1, 1, 1)
+
+    result = duck_array_ops.datetime_to_numeric(
+        times, offset=offset, datetime_unit="D", dtype=int
+    )
+    cfresult = duck_array_ops.datetime_to_numeric(
+        cftimes, offset=cfoffset, datetime_unit="D", dtype=int
+    )
+
+    expected = 730119 + np.arange(0, 35, 7)
+
+    np.testing.assert_array_equal(result, expected)
+    np.testing.assert_array_equal(cfresult, expected)
+
+
+def test_py_timedelta_to_float():
+    assert py_timedelta_to_float(dt.timedelta(days=1), "ns") == 86400 * 1e9
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "ps") == 86400 * 1e18
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "ns") == 86400 * 1e15
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "us") == 86400 * 1e12
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "ms") == 86400 * 1e9
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "s") == 86400 * 1e6
+    assert py_timedelta_to_float(dt.timedelta(days=1e6), "D") == 1e6
+
+
+@pytest.mark.parametrize(
+    "td, expected",
+    ([np.timedelta64(1, "D"), 86400 * 1e9], [np.timedelta64(1, "ns"), 1.0]),
+)
+def test_np_timedelta64_to_float(td, expected):
+    out = np_timedelta64_to_float(td, datetime_unit="ns")
+    np.testing.assert_allclose(out, expected)
+    assert isinstance(out, float)
+
+    out = np_timedelta64_to_float(np.atleast_1d(td), datetime_unit="ns")
+    np.testing.assert_allclose(out, expected)
+
+
+@pytest.mark.parametrize(
+    "td, expected", ([pd.Timedelta(1, "D"), 86400 * 1e9], [pd.Timedelta(1, "ns"), 1.0])
+)
+def test_pd_timedelta_to_float(td, expected):
+    out = pd_timedelta_to_float(td, datetime_unit="ns")
+    np.testing.assert_allclose(out, expected)
+    assert isinstance(out, float)
+
+
+@pytest.mark.parametrize(
+    "td", [dt.timedelta(days=1), np.timedelta64(1, "D"), pd.Timedelta(1, "D"), "1 day"]
+)
+def test_timedelta_to_numeric(td):
+    # Scalar input
+    out = timedelta_to_numeric(td, "ns")
+    np.testing.assert_allclose(out, 86400 * 1e9)
+    assert isinstance(out, float)
diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py
index 9a1f0bbd975..61ecf46b79b 100644
--- a/xarray/tests/test_formatting.py
+++ b/xarray/tests/test_formatting.py
@@ -3,6 +3,7 @@

 import numpy as np
 import pandas as pd
+import pytest

 import xarray as xr
 from xarray.core import formatting
@@ -275,6 +276,44 @@ def test_diff_array_repr(self):
         except AssertionError:
             assert actual == expected.replace(", dtype=int64", "")

+    @pytest.mark.filterwarnings("error")
+    def test_diff_attrs_repr_with_array(self):
+        attrs_a = {"attr": np.array([0, 1])}
+
+        attrs_b = {"attr": 1}
+        expected = dedent(
+            """\
+            Differing attributes:
+            L attr: [0 1]
+            R attr: 1
+            """
+        ).strip()
+        actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals")
+        assert expected == actual
+
+        attrs_b = {"attr": np.array([-3, 5])}
+        expected = dedent(
+            """\
+            Differing attributes:
+            L attr: [0 1]
+            R attr: [-3 5]
+            """
+        ).strip()
+        actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals")
+        assert expected == actual
+
+        # should not raise a warning
attrs_b = {"attr": np.array([0, 1, 2])} + expected = dedent( + """\ + Differing attributes: + L attr: [0 1] + R attr: [0 1 2] + """ + ).strip() + actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals") + assert expected == actual + def test_diff_dataset_repr(self): ds_a = xr.Dataset( data_vars={ diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index b93325d7eab..e3af8b5873a 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -662,3 +662,10 @@ def test_datetime_interp_noerror(): coords={"time": pd.date_range("01-01-2001", periods=50, freq="H")}, ) a.interp(x=xi, time=xi.time) # should not raise an error + + +@requires_cftime +def test_3641(): + times = xr.cftime_range("0001", periods=3, freq="500Y") + da = xr.DataArray(range(3), dims=["time"], coords=[times]) + da.interp(time=["0002-05-01"]) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c1e6c7a5ce8..6c8f3f65657 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -3,6 +3,7 @@ import xarray as xr from xarray.core import dtypes, merge +from xarray.testing import assert_identical from . import raises_regex from .test_dataset import create_test_data @@ -253,3 +254,9 @@ def test_merge_no_conflicts(self): with pytest.raises(xr.MergeError): ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) ds1.merge(ds3, compat="no_conflicts") + + def test_merge_dataarray(self): + ds = xr.Dataset({"a": 0}) + da = xr.DataArray(data=1, name="b") + + assert_identical(ds.merge(da), xr.merge([ds, da])) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 1cd0319a9a5..35c71c2854c 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -14,13 +14,16 @@ ) from xarray.core.pycompat import dask_array_type from xarray.tests import ( + assert_allclose, assert_array_equal, assert_equal, raises_regex, requires_bottleneck, + requires_cftime, requires_dask, requires_scipy, ) +from xarray.tests.test_cftime_offsets import _CFTIME_CALENDARS @pytest.fixture @@ -28,6 +31,18 @@ def da(): return xr.DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") +@pytest.fixture +def cf_da(): + def _cf_da(calendar, freq="1D"): + times = xr.cftime_range( + start="1970-01-01", freq=freq, periods=10, calendar=calendar + ) + values = np.arange(10) + return xr.DataArray(values, dims=("time",), coords={"time": times}) + + return _cf_da + + @pytest.fixture def ds(): ds = xr.Dataset() @@ -472,6 +487,42 @@ def test_interpolate_na_nan_block_lengths(y, lengths): assert_equal(actual, expected) +@requires_cftime +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_get_clean_interp_index_cf_calendar(cf_da, calendar): + """The index for CFTimeIndex is in units of days. This means that if two series using a 360 and 365 days + calendar each have a trend of .01C/year, the linear regression coefficients will be different because they + have different number of days. + + Another option would be to have an index in units of years, but this would likely create other difficulties. 
+ """ + i = get_clean_interp_index(cf_da(calendar), dim="time") + np.testing.assert_array_equal(i, np.arange(10) * 1e9 * 86400) + + +@requires_cftime +@pytest.mark.parametrize( + ("calendar", "freq"), zip(["gregorian", "proleptic_gregorian"], ["1D", "1M", "1Y"]) +) +def test_get_clean_interp_index_dt(cf_da, calendar, freq): + """In the gregorian case, the index should be proportional to normal datetimes.""" + g = cf_da(calendar, freq=freq) + g["stime"] = xr.Variable(data=g.time.to_index().to_datetimeindex(), dims=("time",)) + + gi = get_clean_interp_index(g, "time") + si = get_clean_interp_index(g, "time", use_coordinate="stime") + np.testing.assert_array_equal(gi, si) + + +def test_get_clean_interp_index_potential_overflow(): + da = xr.DataArray( + [0, 1, 2], + dims=("time",), + coords={"time": xr.cftime_range("0000-01-01", periods=3, calendar="360_day")}, + ) + get_clean_interp_index(da, "time") + + @pytest.fixture def da_time(): return xr.DataArray( @@ -490,7 +541,7 @@ def test_interpolate_na_max_gap_errors(da_time): da_time.interpolate_na("t", max_gap=(1,)) da_time["t"] = pd.date_range("2001-01-01", freq="H", periods=11) - with raises_regex(TypeError, "Underlying index is"): + with raises_regex(TypeError, "Expected value of type str"): da_time.interpolate_na("t", max_gap=1) with raises_regex(TypeError, "Expected integer or floating point"): @@ -501,10 +552,7 @@ def test_interpolate_na_max_gap_errors(da_time): @requires_bottleneck -@pytest.mark.parametrize( - "time_range_func", - [pd.date_range, pytest.param(xr.cftime_range, marks=pytest.mark.xfail)], -) +@pytest.mark.parametrize("time_range_func", [pd.date_range, xr.cftime_range]) @pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.to_dataset(name="a")]) @pytest.mark.parametrize( "max_gap", ["3H", np.timedelta64(3, "h"), pd.to_timedelta("3H")] @@ -517,7 +565,7 @@ def test_interpolate_na_max_gap_time_specifier( da_time.copy(data=[np.nan, 1, 2, 3, 4, 5, np.nan, np.nan, np.nan, np.nan, 10]) ) actual = transform(da_time).interpolate_na("t", max_gap=max_gap) - assert_equal(actual, expected) + assert_allclose(actual, expected) @requires_bottleneck diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 330c975b999..c9ed3706c58 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -15,7 +15,6 @@ _build_discrete_cmap, _color_palette, _determine_cmap_params, - import_seaborn, label_from_attrs, ) @@ -63,6 +62,15 @@ def substring_in_axes(substring, ax): return False +def substring_not_in_axes(substring, ax): + """ + Return True if a substring is not found anywhere in an axes + """ + alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)} + check = [(substring not in txt) for txt in alltxt] + return all(check) + + def easy_array(shape, start=0, stop=1): """ Make an array with desired shape using np.linspace @@ -483,9 +491,25 @@ def test_convenient_facetgrid_4d(self): d.plot(x="x", y="y", col="columns", ax=plt.gca()) def test_coord_with_interval(self): + """Test line plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot() + def test_coord_with_interval_x(self): + """Test line plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") + + def test_coord_with_interval_y(self): + """Test line plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") + + def test_coord_with_interval_xy(self): + 
"""Test line plot with intervals on both x and y axes.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -568,10 +592,23 @@ def test_step(self): self.darray[0, 0].plot.step() def test_coord_with_interval_step(self): + """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + def test_coord_with_interval_step_x(self): + """Test step plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_y(self): + """Test step plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True) @@ -1842,6 +1879,18 @@ def test_default_labels(self): for label, ax in zip(self.darray.coords["col"].values, g.axes[0, :]): assert substring_in_axes(label, ax) + # ensure that row & col labels can be changed + g.set_titles("abc={value}") + for label, ax in zip(self.darray.coords["row"].values, g.axes[:, -1]): + assert substring_in_axes(f"abc={label}", ax) + # previous labels were "row=row0" etc. + assert substring_not_in_axes("row=", ax) + + for label, ax in zip(self.darray.coords["col"].values, g.axes[0, :]): + assert substring_in_axes(f"abc={label}", ax) + # previous labels were "col=row0" etc. + assert substring_not_in_axes("col=", ax) + @pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetedLinePlotsLegend(PlotTestCase): @@ -2134,22 +2183,6 @@ def test_ncaxis_notinstalled_line_plot(self): self.darray.plot.line() -@requires_seaborn -def test_import_seaborn_no_warning(): - # GH1633 - with pytest.warns(None) as record: - import_seaborn() - assert len(record) == 0 - - -@requires_matplotlib -def test_plot_seaborn_no_import_warning(): - # GH1633 - with pytest.warns(None) as record: - _color_palette("Blues", 4) - assert len(record) == 0 - - test_da_list = [ DataArray(easy_array((10,))), DataArray(easy_array((10, 3))), diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index a02fef2faeb..21a212c29b3 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -873,3 +873,16 @@ def test_dask_token(): t5 = dask.base.tokenize(ac + 1) assert t4 != t5 assert isinstance(ac.data._meta, sparse.COO) + + +@requires_dask +def test_apply_ufunc_meta_to_blockwise(): + da = xr.DataArray(np.zeros((2, 3)), dims=["x", "y"]).chunk({"x": 2, "y": 1}) + sparse_meta = sparse.COO.from_numpy(np.zeros((0, 0))) + + # if dask computed meta, it would be np.ndarray + expected = xr.apply_ufunc( + lambda x: x, da, dask="parallelized", output_dtypes=[da.dtype], meta=sparse_meta + ).data._meta + + assert_sparse_equal(expected, sparse_meta) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 0be6f8af464..75e743c3455 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1,4 +1,5 @@ import operator +from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -8,13 +9,18 @@ from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE +from .test_variable import 
VariableSubclassobjects + pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError -unit_registry = pint.UnitRegistry() +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays +unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity + pytestmark = [ pytest.mark.skipif( not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled" @@ -28,10 +34,31 @@ ] +def is_compatible(unit1, unit2): + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless + + return unit_like.dimensionality + + return dimensionality(unit1) == dimensionality(unit2) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + def array_extract_units(obj): - raw = obj.data if hasattr(obj, "data") else obj + if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): + obj = obj.data + try: - return raw.units + return obj.units except AttributeError: return None @@ -111,8 +138,12 @@ def extract_units(obj): } units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} elif isinstance(obj, Quantity): - vars_units = {"": array_extract_units(obj)} + vars_units = {None: array_extract_units(obj)} units = {**vars_units} else: @@ -146,6 +177,9 @@ def strip_units(obj): new_obj = xr.DataArray( name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + data = array_strip_units(obj.data) + new_obj = obj.copy(data=data) elif isinstance(obj, unit_registry.Quantity): new_obj = obj.magnitude elif isinstance(obj, (list, tuple)): @@ -157,8 +191,9 @@ def strip_units(obj): def attach_units(obj, units): - if not isinstance(obj, (xr.DataArray, xr.Dataset)): - return array_attach_units(obj, units.get("data", 1)) + if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): + units = units.get("data", None) or units.get(None, None) or 1 + return array_attach_units(obj, units) if isinstance(obj, xr.Dataset): data_vars = { @@ -170,7 +205,7 @@ def attach_units(obj, units): } new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: + elif isinstance(obj, xr.DataArray): # try the array name, "data" and None, then fall back to dimensionless data_units = ( units.get(obj.name, None) @@ -196,6 +231,11 @@ def attach_units(obj, units): new_obj = xr.DataArray( name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims ) + else: + data_units = units.get("data", None) or units.get(None, None) or 1 + + data = array_attach_units(obj.data, data_units) + new_obj = obj.copy(data=data) return new_obj @@ -203,21 +243,25 @@ def attach_units(obj, units): def convert_units(obj, to): if isinstance(obj, xr.Dataset): data_vars = { - name: convert_units(array, to) for name, array in obj.data_vars.items() + name: convert_units(array.variable, {None: to.get(name)}) + for name, array in obj.data_vars.items() + } + coords = { + name: convert_units(array.variable, {None: to.get(name)}) + for name, array in obj.coords.items() } - coords = {name: convert_units(array, to) for name, array in obj.coords.items()} new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) elif isinstance(obj, xr.DataArray): name = obj.name new_units = ( - to.get(name, None) or 
to.get("data", None) or to.get(None, None) or 1 + to.get(name, None) or to.get("data", None) or to.get(None, None) or None ) - data = convert_units(obj.data, {None: new_units}) + data = convert_units(obj.variable, {None: new_units}) coords = { - name: (array.dims, convert_units(array.data, to)) + name: (array.dims, convert_units(array.variable, {None: to.get(name)})) for name, array in obj.coords.items() if name != obj.name } @@ -225,6 +269,9 @@ def convert_units(obj, to): new_obj = xr.DataArray( name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + new_data = convert_units(obj.data, to) + new_obj = obj.copy(data=new_data) elif isinstance(obj, unit_registry.Quantity): units = to.get(None) new_obj = obj.to(units) if units is not None else obj @@ -234,6 +281,10 @@ def convert_units(obj, to): return new_obj +def assert_units_equal(a, b): + assert extract_units(a) == extract_units(b) + + def assert_equal_with_units(a, b): # works like xr.testing.assert_equal, but also explicitly checks units # so, it is more like assert_identical @@ -273,7 +324,27 @@ def dtype(request): return request.param +def merge_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + +def merge_args(default_args, new_args): + from itertools import zip_longest + + fill_value = object() + return [ + second if second is not fill_value else first + for first, second in zip_longest(default_args, new_args, fillvalue=fill_value) + ] + + class method: + """ wrapper class to help with passing methods via parametrize + + This is works a bit similar to using `partial(Class.method, arg, kwarg)` + """ + def __init__(self, name, *args, **kwargs): self.name = name self.args = args @@ -283,7 +354,7 @@ def __call__(self, obj, *args, **kwargs): from collections.abc import Callable from functools import partial - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} func = getattr(obj, self.name, None) @@ -292,7 +363,7 @@ def __call__(self, obj, *args, **kwargs): if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xr args like "dim" + # remove typical xarray args like "dim" exclude_kwargs = ("dim", "dims") all_kwargs = { key: value @@ -309,12 +380,21 @@ def __repr__(self): class function: - def __init__(self, name_or_function, *args, **kwargs): + """ wrapper class for numpy functions + + Same as method, but the name is used for referencing numpy functions + """ + + def __init__(self, name_or_function, *args, function_label=None, **kwargs): if callable(name_or_function): - self.name = name_or_function.__name__ + self.name = ( + function_label + if function_label is not None + else name_or_function.__name__ + ) self.func = name_or_function else: - self.name = name_or_function + self.name = name_or_function if function_label is None else function_label self.func = getattr(np, name_or_function) if self.func is None: raise AttributeError( @@ -325,7 +405,7 @@ def __init__(self, name_or_function, *args, **kwargs): self.kwargs = kwargs def __call__(self, *args, **kwargs): - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} return self.func(*all_args, **all_kwargs) @@ -334,6 +414,7 @@ def __repr__(self): return f"function_{self.name}" +@pytest.mark.xfail(reason="test bug: apply_ufunc should 
not be called that way") def test_apply_ufunc_dataarray(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -344,14 +425,35 @@ def test_apply_ufunc_dataarray(dtype): data_array = xr.DataArray(data=array, dims="x", coords={"x": x}) expected = attach_units(func(strip_units(data_array)), extract_units(data_array)) - result = func(data_array) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) -@pytest.mark.xfail( - reason="pint does not implement `np.result_type` and align strips units" -) +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") +def test_apply_ufunc_dataset(dtype): + func = function( + xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} + ) + + array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(0, 10, 5).astype(dtype) * unit_registry.m + + x = np.arange(5) * unit_registry.s + y = np.arange(10) * unit_registry.m + + ds = xr.Dataset( + data_vars={"a": (("x", "y"), array1), "b": ("x", array2)}, + coords={"x": x, "y": y}, + ) + + expected = attach_units(func(strip_units(ds)), extract_units(ds)) + actual = func(ds) + + assert_equal_with_units(expected, actual) + + +@pytest.mark.xfail(reason="blocked by `reindex` / `where`") @pytest.mark.parametrize( "unit,error", ( @@ -378,9 +480,9 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): original_unit = unit_registry.m variants = { - "data": (unit, original_unit, original_unit), - "dims": (original_unit, unit, original_unit), - "coords": (original_unit, original_unit, unit), + "data": (unit, 1, 1), + "dims": (original_unit, unit, 1), + "coords": (original_unit, 1, unit), } data_unit, dim_unit, coord_unit = variants.get(variant) @@ -410,32 +512,27 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): stripped_kwargs = { key: strip_units( - convert_units(value, {None: original_unit}) - if isinstance(value, unit_registry.Quantity) - else value + convert_units(value, {None: original_unit if data_unit != 1 else None}) ) for key, value in func.kwargs.items() } - units = extract_units(data_array1) - # FIXME: should the expected_b have the same units as data_array1 - # or data_array2? 
- expected_a, expected_b = tuple( - attach_units(elem, units) - for elem in func( - strip_units(data_array1), - strip_units(convert_units(data_array2, units)), - **stripped_kwargs, - ) + units_a = extract_units(data_array1) + units_b = extract_units(data_array2) + expected_a, expected_b = func( + strip_units(data_array1), + strip_units(convert_units(data_array2, units_a)), + **stripped_kwargs, ) - result_a, result_b = func(data_array1, data_array2) + expected_a = attach_units(expected_a, units_a) + expected_b = convert_units(attach_units(expected_b, units_a), units_b) - assert_equal_with_units(expected_a, result_a) - assert_equal_with_units(expected_b, result_b) + actual_a, actual_b = func(data_array1, data_array2) + assert_equal_with_units(expected_a, actual_a) + assert_equal_with_units(expected_b, actual_b) -@pytest.mark.xfail( - reason="pint does not implement `np.result_type` and align strips units" -) + +@pytest.mark.xfail(reason="blocked by `reindex` / `where`") @pytest.mark.parametrize( "unit,error", ( @@ -461,11 +558,7 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): def test_align_dataset(fill_value, unit, variant, error, dtype): original_unit = unit_registry.m - variants = { - "data": (unit, original_unit, original_unit), - "dims": (original_unit, unit, original_unit), - "coords": (original_unit, original_unit, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit @@ -497,24 +590,22 @@ def test_align_dataset(fill_value, unit, variant, error, dtype): stripped_kwargs = { key: strip_units( - convert_units(value, {None: original_unit}) - if isinstance(value, unit_registry.Quantity) - else value + convert_units(value, {None: original_unit if data_unit != 1 else None}) ) for key, value in func.kwargs.items() } - units = extract_units(ds1) - # FIXME: should the expected_b have the same units as ds1 or ds2? 
- expected_a, expected_b = tuple( - attach_units(elem, units) - for elem in func( - strip_units(ds1), strip_units(convert_units(ds2, units)), **stripped_kwargs - ) + units_a = extract_units(ds1) + units_b = extract_units(ds2) + expected_a, expected_b = func( + strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs ) - result_a, result_b = func(ds1, ds2) + expected_a = attach_units(expected_a, units_a) + expected_b = convert_units(attach_units(expected_b, units_a), units_b) - assert_equal_with_units(expected_a, result_a) - assert_equal_with_units(expected_b, result_b) + actual_a, actual_b = func(ds1, ds2) + + assert_equal_with_units(expected_a, actual_a) + assert_equal_with_units(expected_b, actual_b) def test_broadcast_dataarray(dtype): @@ -528,10 +619,10 @@ def test_broadcast_dataarray(dtype): attach_units(elem, extract_units(a)) for elem in xr.broadcast(strip_units(a), strip_units(b)) ) - result_a, result_b = xr.broadcast(a, b) + actual_a, actual_b = xr.broadcast(a, b) - assert_equal_with_units(expected_a, result_a) - assert_equal_with_units(expected_b, result_b) + assert_equal_with_units(expected_a, actual_a) + assert_equal_with_units(expected_b, actual_b) def test_broadcast_dataset(dtype): @@ -543,12 +634,11 @@ def test_broadcast_dataset(dtype): (expected,) = tuple( attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds)) ) - (result,) = xr.broadcast(ds) + (actual,) = xr.broadcast(ds) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) -@pytest.mark.xfail(reason="`combine_by_coords` strips units") @pytest.mark.parametrize( "unit,error", ( @@ -614,12 +704,11 @@ def test_combine_by_coords(variant, unit, error, dtype): ), units, ) - result = xr.combine_by_coords([ds, other]) + actual = xr.combine_by_coords([ds, other]) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) -@pytest.mark.xfail(reason="blocked by `where`") @pytest.mark.parametrize( "unit,error", ( @@ -628,7 +717,12 @@ def test_combine_by_coords(variant, unit, error, dtype): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param( + unit_registry.mm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="wrong order of arguments to `where`"), + ), pytest.param(unit_registry.m, None, id="identical_unit"), ), ids=repr, @@ -714,12 +808,11 @@ def test_combine_nested(variant, unit, error, dtype): ), units, ) - result = func([[ds1, ds2], [ds3, ds4]]) + actual = func([[ds1, ds2], [ds3, ds4]]) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) -@pytest.mark.xfail(reason="`concat` strips units") @pytest.mark.parametrize( "unit,error", ( @@ -760,15 +853,18 @@ def test_concat_dataarray(variant, unit, error, dtype): return + units = extract_units(arr1) expected = attach_units( - xr.concat([strip_units(arr1), strip_units(arr2)], dim="x"), extract_units(arr1) + xr.concat( + [strip_units(arr1), strip_units(convert_units(arr2, units))], dim="x" + ), + units, ) - result = xr.concat([arr1, arr2], dim="x") + actual = xr.concat([arr1, arr2], dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) -@pytest.mark.xfail(reason="`concat` strips units") @pytest.mark.parametrize( "unit,error", ( @@ -809,15 +905,17 @@ def test_concat_dataset(variant, unit, error, dtype): return + units 
     expected = attach_units(
-        xr.concat([strip_units(ds1), strip_units(ds2)], dim="x"), extract_units(ds1)
+        xr.concat([strip_units(ds1), strip_units(convert_units(ds2, units))], dim="x"),
+        units,
     )
-    result = xr.concat([ds1, ds2], dim="x")
+    actual = xr.concat([ds1, ds2], dim="x")

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


-@pytest.mark.xfail(reason="blocked by `where`")
+@pytest.mark.xfail(reason="blocked by `reindex` / `where`")
 @pytest.mark.parametrize(
     "unit,error",
     (
@@ -902,12 +1000,12 @@ def test_merge_dataarray(variant, unit, error, dtype):
         func([strip_units(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]),
         units,
     )
-    result = func([arr1, arr2, arr3])
+    actual = func([arr1, arr2, arr3])

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


-@pytest.mark.xfail(reason="blocked by `where`")
+@pytest.mark.xfail(reason="blocked by `reindex` / `where`")
 @pytest.mark.parametrize(
     "unit,error",
     (
@@ -985,9 +1083,9 @@ def test_merge_dataset(variant, unit, error, dtype):
     expected = attach_units(
         func([strip_units(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), units
     )
-    result = func([ds1, ds2, ds3])
+    actual = func([ds1, ds2, ds3])

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


 @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
@@ -997,9 +1095,9 @@ def test_replication_dataarray(func, dtype):
     numpy_func = getattr(np, func.__name__)
     expected = xr.DataArray(data=numpy_func(array), dims="x")
-    result = func(data_array)
+    actual = func(data_array)

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


 @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
@@ -1019,9 +1117,9 @@ def test_replication_dataset(func, dtype):
     expected = ds.copy(
         data={name: numpy_func(array.data) for name, array in ds.data_vars.items()}
     )
-    result = func(ds)
+    actual = func(ds)

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


 @pytest.mark.xfail(
@@ -1051,11 +1149,16 @@ def test_replication_full_like_dataarray(unit, error, dtype):
     if error is not None:
         with pytest.raises(error):
             xr.full_like(data_array, fill_value=fill_value)
-    else:
-        result = xr.full_like(data_array, fill_value=fill_value)
-        expected = np.full_like(array, fill_value=fill_value)
-        assert_equal_with_units(expected, result)
+
+        return
+
+    units = {**extract_units(data_array), **{None: unit if unit != 1 else None}}
+    expected = attach_units(
+        xr.full_like(strip_units(data_array), fill_value=strip_units(fill_value)), units
+    )
+    actual = xr.full_like(data_array, fill_value=fill_value)
+
+    assert_equal_with_units(expected, actual)


 @pytest.mark.xfail(
@@ -1096,18 +1199,18 @@ def test_replication_full_like_dataset(unit, error, dtype):

         return

-    expected = ds.copy(
-        data={
-            name: np.full_like(array, fill_value=fill_value)
-            for name, array in ds.data_vars.items()
-        }
+    units = {
+        **extract_units(ds),
+        **{name: unit if unit != 1 else None for name in ds.data_vars},
+    }
+    expected = attach_units(
+        xr.full_like(strip_units(ds), fill_value=strip_units(fill_value)), units
     )
-    result = xr.full_like(ds, fill_value=fill_value)
+    actual = xr.full_like(ds, fill_value=fill_value)

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


-@pytest.mark.xfail(reason="`where` strips units")
 @pytest.mark.parametrize(
     "unit,error",
     (
@@ -1127,30 +1230,29 @@ def test_where_dataarray(fill_value, unit, error, dtype):
     x = xr.DataArray(data=array, dims="x")
     cond = x < 5 * unit_registry.m
-    # FIXME: this should work without wrapping in array()
-    fill_value = np.array(fill_value) * unit
+    fill_value = fill_value * unit

-    if error is not None:
+    if error is not None and not (
+        np.isnan(fill_value) and not isinstance(fill_value, Quantity)
+    ):
         with pytest.raises(error):
             xr.where(cond, x, fill_value)

         return

-    fill_value_ = (
-        fill_value.to(unit_registry.m)
-        if isinstance(fill_value, unit_registry.Quantity)
-        and fill_value.check(unit_registry.m)
-        else fill_value
-    )
     expected = attach_units(
-        xr.where(cond, strip_units(x), strip_units(fill_value_)), extract_units(x)
+        xr.where(
+            cond,
+            strip_units(x),
+            strip_units(convert_units(fill_value, {None: unit_registry.m})),
+        ),
+        extract_units(x),
     )
-    result = xr.where(cond, x, fill_value)
+    actual = xr.where(cond, x, fill_value)

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


-@pytest.mark.xfail(reason="`where` strips units")
 @pytest.mark.parametrize(
     "unit,error",
     (
@@ -1171,31 +1273,30 @@ def test_where_dataset(fill_value, unit, error, dtype):
     x = np.arange(10) * unit_registry.s

     ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
-    cond = ds.x < 5 * unit_registry.s
-    # FIXME: this should work without wrapping in array()
-    fill_value = np.array(fill_value) * unit
+    cond = x < 5 * unit_registry.s
+    fill_value = fill_value * unit

-    if error is not None:
+    if error is not None and not (
+        np.isnan(fill_value) and not isinstance(fill_value, Quantity)
+    ):
         with pytest.raises(error):
             xr.where(cond, ds, fill_value)

         return

-    fill_value_ = (
-        fill_value.to(unit_registry.m)
-        if isinstance(fill_value, unit_registry.Quantity)
-        and fill_value.check(unit_registry.m)
-        else fill_value
-    )
     expected = attach_units(
-        xr.where(cond, strip_units(ds), strip_units(fill_value_)), extract_units(ds)
+        xr.where(
+            cond,
+            strip_units(ds),
+            strip_units(convert_units(fill_value, {None: unit_registry.m})),
+        ),
+        extract_units(ds),
     )
-    result = xr.where(cond, ds, fill_value)
+    actual = xr.where(cond, ds, fill_value)

-    assert_equal_with_units(expected, result)
+    assert_equal_with_units(expected, actual)


-@pytest.mark.xfail(reason="pint does not implement `np.einsum`")
 def test_dot_dataarray(dtype):
     array1 = (
         np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
@@ -1206,126 +1307,899 @@ def test_dot_dataarray(dtype):
         np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s
     )

-    arr1 = xr.DataArray(data=array1, dims=("x", "y"))
-    arr2 = xr.DataArray(data=array2, dims=("y", "z"))
+    data_array = xr.DataArray(data=array1, dims=("x", "y"))
+    other = xr.DataArray(data=array2, dims=("y", "z"))
+
+    expected = attach_units(
+        xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m}
+    )
+    actual = xr.dot(data_array, other)
+
+    assert_equal_with_units(expected, actual)
+
+
+def delete_attrs(*to_delete):
+    def wrapper(cls):
+        for item in to_delete:
+            setattr(cls, item, None)
+
+        return cls
+
+    return wrapper
+
+
+@delete_attrs(
+    "test_getitem_with_mask",
+    "test_getitem_with_mask_nd_indexer",
+    "test_index_0d_string",
+    "test_index_0d_datetime",
+    "test_index_0d_timedelta64",
+    "test_0d_time_data",
+    "test_datetime64_conversion",
+    "test_timedelta64_conversion",
+    "test_pandas_period_index",
+    "test_1d_math",
+    "test_1d_reduce",
+    "test_array_interface",
+    "test___array__",
+    "test_copy_index",
+    "test_concat_number_strings",
+    "test_concat_fixed_len_str",
+    "test_concat_mixed_dtypes",
+    "test_pandas_datetime64_with_tz",
+    "test_pandas_data",
+    "test_multiindex",
+)
+class TestVariable(VariableSubclassobjects):
+    @staticmethod
+    def cls(dims, data, *args, **kwargs):
+        return xr.Variable(
+            dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs
+        )

-    expected = array1.dot(array2)
-    result = xr.dot(arr1, arr2)
+    @pytest.mark.parametrize(
+        "func",
+        (
+            method("all"),
+            method("any"),
+            method("argmax"),
+            method("argmin"),
+            method("argsort"),
+            method("cumprod"),
+            method("cumsum"),
+            method("max"),
+            method("mean"),
+            method("median"),
+            method("min"),
+            pytest.param(
+                method("prod"),
+                marks=pytest.mark.xfail(reason="not implemented by pint"),
+            ),
+            method("std"),
+            method("sum"),
+            method("var"),
+        ),
+        ids=repr,
+    )
+    def test_aggregation(self, func, dtype):
+        array = np.linspace(0, 1, 10).astype(dtype) * (
+            unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
+        )
+        variable = xr.Variable("x", array)

-    assert_equal_with_units(expected, result)
+        units = extract_units(func(array))
+        expected = attach_units(func(strip_units(variable)), units)
+        actual = func(variable)

+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)

-class TestDataArray:
-    @pytest.mark.filterwarnings("error:::pint[.*]")
     @pytest.mark.parametrize(
-        "variant",
+        "func",
+        (
+            method("astype", np.float32),
+            method("conj"),
+            method("conjugate"),
+            method("clip", min=2, max=7),
+        ),
+        ids=repr,
+    )
+    @pytest.mark.parametrize(
+        "unit,error",
         (
+            pytest.param(1, DimensionalityError, id="no_unit"),
             pytest.param(
-                "with_dims",
-                marks=pytest.mark.xfail(reason="units in indexes are not supported"),
+                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
             ),
-            pytest.param("with_coords"),
-            pytest.param("without_coords"),
+            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+            pytest.param(unit_registry.cm, None, id="compatible_unit"),
+            pytest.param(unit_registry.m, None, id="identical_unit"),
         ),
     )
-    def test_init(self, variant, dtype):
-        array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m
+    def test_numpy_methods(self, func, unit, error, dtype):
+        array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m
+        variable = xr.Variable("x", array)

-        x = np.arange(len(array)) * unit_registry.s
-        y = x.to(unit_registry.ms)
-
-        variants = {
-            "with_dims": {"x": x},
-            "with_coords": {"y": ("x", y)},
-            "without_coords": {},
+        args = [
+            item * unit if isinstance(item, (int, float, list)) else item
+            for item in func.args
+        ]
+        kwargs = {
+            key: value * unit if isinstance(value, (int, float, list)) else value
+            for key, value in func.kwargs.items()
         }

-        kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)}
-        data_array = xr.DataArray(**kwargs)
+        if error is not None and func.name in ("searchsorted", "clip"):
+            with pytest.raises(error):
+                func(variable, *args, **kwargs)

-        assert isinstance(data_array.data, Quantity)
-        assert all(
-            {
-                name: isinstance(coord.data, Quantity)
-                for name, coord in data_array.coords.items()
-            }.values()
+            return
+
+        converted_args = [
+            strip_units(convert_units(item, {None: unit_registry.m})) for item in args
+        ]
+        converted_kwargs = {
+            key: strip_units(convert_units(value, {None: unit_registry.m}))
+            for key, value in kwargs.items()
+        }
+
+        units = extract_units(func(array, *args, **kwargs))
+        expected = attach_units(
+            func(strip_units(variable), *converted_args, **converted_kwargs), units
         )
+        actual = func(variable, *args, **kwargs)
*args, **kwargs) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) - @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( - "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) + "func", (method("item", 5), method("searchsorted", 5)), ids=repr ) @pytest.mark.parametrize( - "variant", + "unit,error", ( + pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - "with_dims", - marks=pytest.mark.xfail(reason="units in indexes are not supported"), + unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), - pytest.param("with_coords"), - pytest.param("without_coords"), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_repr(self, func, variant, dtype): - array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m - x = np.arange(len(array)) * unit_registry.s - y = x.to(unit_registry.ms) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) - variants = { - "with_dims": {"x": x}, - "with_coords": {"y": ("x", y)}, - "without_coords": {}, + args = [ + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + for item in func.args + ] + kwargs = { + key: value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + for key, value in func.kwargs.items() } - kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)} - data_array = xr.DataArray(**kwargs) + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) - # FIXME: this just checks that the repr does not raise - # warnings or errors, but does not check the result - func(data_array) + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) @pytest.mark.parametrize( - "func", + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", ( + pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - function("all"), - marks=pytest.mark.xfail(reason="not implemented by pint yet"), + unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), - pytest.param( - function("any"), - marks=pytest.mark.xfail(reason="not implemented by pint yet"), + pytest.param(unit_registry.s, DimensionalityError, 
id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_missing_value_fillna(self, unit, error): + value = 10 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + variable.fillna(value=fill_value) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude ), - pytest.param( - function("argmax"), - marks=pytest.mark.xfail( - reason="comparison of quantity with ndarrays in nanops not implemented" - ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit",), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("equals"), + pytest.param( + method("identical"), + marks=pytest.mark.skip(reason="behaviour of identical is unclear"), + ), + ), + ids=repr, + ) + def test_comparisons(self, func, unit, convert_data, dtype): + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and is_compatible(unit_registry.m, unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if is_compatible(unit_registry.m, unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() + ) + + actual = func(variable, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + @pytest.mark.parametrize( + "indices", + 
( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.s + variable = xr.Variable("x", array) + + expected = attach_units( + strip_units(variable).isel(x=indices), extract_units(variable) + ) + actual = variable.isel(x=indices) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None}, + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_squeeze(self, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") 
+        array = np.ones(shape=shape) * unit_registry.m
+        variable = xr.Variable(names, array)
+
+        expected = attach_units(
+            strip_units(variable).squeeze(), extract_units(variable)
+        )
+        actual = variable.squeeze()
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+        names = tuple(name for name, size in zip(names, shape) if size == 1)
+        for name in names:
+            expected = attach_units(
+                strip_units(variable).squeeze(dim=name), extract_units(variable)
+            )
+            actual = variable.squeeze(dim=name)
+
+            assert_units_equal(expected, actual)
+            xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.parametrize(
+        "func",
+        (
+            method("coarsen", windows={"y": 2}, func=np.mean),
+            pytest.param(
+                method("quantile", q=[0.25, 0.75]),
+                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
+            ),
+            pytest.param(
+                method("rank", dim="x"),
+                marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"),
+            ),
+            method("roll", {"x": 2}),
+            pytest.param(
+                method("rolling_window", "x", 3, "window"),
+                marks=pytest.mark.xfail(reason="converts to ndarray"),
+            ),
+            method("reduce", np.std, "x"),
+            method("round", 2),
+            method("shift", {"x": -2}),
+            method("transpose", "y", "x"),
+        ),
+        ids=repr,
+    )
+    def test_computation(self, func, dtype):
+        base_unit = unit_registry.m
+        array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit
+        variable = xr.Variable(("x", "y"), array)
+
+        expected = attach_units(func(strip_units(variable)), extract_units(variable))
+
+        actual = func(variable)
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.parametrize(
+        "unit,error",
+        (
+            pytest.param(1, DimensionalityError, id="no_unit"),
+            pytest.param(
+                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+            ),
+            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+            pytest.param(unit_registry.cm, None, id="compatible_unit"),
+            pytest.param(unit_registry.m, None, id="identical_unit"),
+        ),
+    )
+    def test_searchsorted(self, unit, error, dtype):
+        base_unit = unit_registry.m
+        array = np.linspace(0, 5, 10).astype(dtype) * base_unit
+        variable = xr.Variable("x", array)
+
+        value = 0 * unit
+
+        if error is not None:
+            with pytest.raises(error):
+                variable.searchsorted(value)
+
+            return
+
+        expected = strip_units(variable).searchsorted(
+            strip_units(convert_units(value, {None: base_unit}))
+        )
+
+        actual = variable.searchsorted(value)
+
+        assert_units_equal(expected, actual)
+        np.testing.assert_allclose(expected, actual)
+
+    def test_stack(self, dtype):
+        array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
+        variable = xr.Variable(("x", "y"), array)
+
+        expected = attach_units(
+            strip_units(variable).stack(z=("x", "y")), extract_units(variable)
+        )
+        actual = variable.stack(z=("x", "y"))
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    def test_unstack(self, dtype):
+        array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m
+        variable = xr.Variable("z", array)
+
+        expected = attach_units(
+            strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable)
+        )
+        actual = variable.unstack(z={"x": 3, "y": 10})
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.parametrize(
+        "unit,error",
+        (
+            pytest.param(1, DimensionalityError, id="no_unit"),
+            pytest.param(
+                unit_registry.dimensionless, DimensionalityError,
id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 9 * 10).reshape(3, 6, 5).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 3).reshape(3, 2, 5).astype(dtype) * unit + + variable = xr.Variable(("x", "y", "z"), array1) + other = xr.Variable(("x", "y", "z"), array2) + + if error is not None: + with pytest.raises(error): + xr.Variable.concat([variable, other], dim="y") + + return + + units = extract_units(variable) + expected = attach_units( + xr.Variable.concat( + [strip_units(variable), strip_units(convert_units(other, units))], + dim="y", + ), + units, + ) + actual = xr.Variable.concat([variable, other], dim="y") + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + + def test_pad(self, dtype): + data = np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype) * unit_registry.m + v = xr.Variable(["x", "y", "z"], data) + + xr_args = [{"x": (2, 1)}, {"y": (0, 3)}, {"x": (3, 1), "z": (2, 0)}] + np_args = [ + ((2, 1), (0, 0), (0, 0)), + ((0, 0), (0, 3), (0, 0)), + ((3, 1), (0, 0), (2, 0)), + ] + for xr_arg, np_arg in zip(xr_args, np_args): + actual = v.pad_with_fill_value(**xr_arg) + expected = xr.Variable( + v.dims, + np.pad( + v.data.astype(float), + np_arg, + mode="constant", + constant_values=np.nan, + ), + ) + xr.testing.assert_identical(expected, actual) + assert_units_equal(expected, actual) + assert isinstance(actual._data, type(v._data)) + + # for the boolean array, we pad False + data = np.full_like(data, False, 
dtype=bool).reshape(4, 3, 2) + v = xr.Variable(["x", "y", "z"], data) + for xr_arg, np_arg in zip(xr_args, np_args): + actual = v.pad_with_fill_value(fill_value=data.flat[0], **xr_arg) + expected = xr.Variable( + v.dims, + np.pad(v.data, np_arg, mode="constant", constant_values=v.data.flat[0]), + ) + xr.testing.assert_identical(actual, expected) + assert_units_equal(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail( + LooseVersion(pint.__version__) < LooseVersion("0.10.2"), + reason="bug in pint's implementation of np.pad", + ), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_with_fill_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = -100 * unit + + func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, fill_value=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + strip_units(variable), + fill_value=strip_units(convert_units(fill_value, units)), ), + units, + ) + actual = func(variable, fill_value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + +class TestDataArray: + @pytest.mark.filterwarnings("error:::pint[.*]") + @pytest.mark.parametrize( + "variant", + ( pytest.param( - function("argmin"), - marks=pytest.mark.xfail( - reason="comparison of quantity with ndarrays in nanops not implemented" - ), + "with_dims", + marks=pytest.mark.xfail(reason="units in indexes are not supported"), + ), + pytest.param("with_coords"), + pytest.param("without_coords"), + ), + ) + def test_init(self, variant, dtype): + array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m + + x = np.arange(len(array)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "with_dims": {"x": x}, + "with_coords": {"y": ("x", y)}, + "without_coords": {}, + } + + kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)} + data_array = xr.DataArray(**kwargs) + + assert isinstance(data_array.data, Quantity) + assert all( + { + name: isinstance(coord.data, Quantity) + for name, coord in data_array.coords.items() + }.values() + ) + + @pytest.mark.filterwarnings("error:::pint[.*]") + @pytest.mark.parametrize( + "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) + ) + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "with_dims", + marks=pytest.mark.xfail(reason="units in indexes are not supported"), + ), + pytest.param("with_coords"), + pytest.param("without_coords"), + ), + ) + def test_repr(self, func, variant, dtype): + array = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.s + y = x.to(unit_registry.ms) + + variants = { + "with_dims": {"x": x}, + "with_coords": {"y": ("x", y)}, + "without_coords": {}, + } + + kwargs = {"data": array, "dims": "x", "coords": variants.get(variant)} + data_array = xr.DataArray(**kwargs) + + # FIXME: this just checks that the repr does not raise + # warnings or errors, but does not check the result + func(data_array) + + 
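+    # The tests below share a common pattern: build the expectation by
+    # stripping the units, applying the operation to the bare data, and
+    # re-attaching the units afterwards (via the strip_units /
+    # convert_units / attach_units / extract_units helpers), then compare
+    # it against the result of applying the operation to the quantity.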
@pytest.mark.parametrize( + "func", + ( + pytest.param( + function("all"), + marks=pytest.mark.xfail(reason="not implemented by pint yet"), + ), + pytest.param( + function("any"), + marks=pytest.mark.xfail(reason="not implemented by pint yet"), ), + function("argmax"), + function("argmin"), function("max"), function("mean"), pytest.param( function("median"), - marks=pytest.mark.xfail( - reason="np.median on DataArray strips the units" - ), + marks=pytest.mark.xfail(reason="not implemented by xarray"), ), function("min"), pytest.param( function("prod"), marks=pytest.mark.xfail(reason="not implemented by pint yet"), ), - pytest.param( - function("sum"), - marks=pytest.mark.xfail( - reason="comparison of quantity with ndarrays in nanops not implemented" - ), - ), + function("sum"), function("std"), function("var"), function("cumsum"), @@ -1341,18 +2215,8 @@ def test_repr(self, func, variant, dtype): method("any"), marks=pytest.mark.xfail(reason="not implemented by pint yet"), ), - pytest.param( - method("argmax"), - marks=pytest.mark.xfail( - reason="comparison of quantities with ndarrays in nanops not implemented" - ), - ), - pytest.param( - method("argmin"), - marks=pytest.mark.xfail( - reason="comparison of quantities with ndarrays in nanops not implemented" - ), - ), + method("argmax"), + method("argmin"), method("max"), method("mean"), method("median"), @@ -1363,12 +2227,7 @@ def test_repr(self, func, variant, dtype): reason="comparison of quantity with ndarrays in nanops not implemented" ), ), - pytest.param( - method("sum"), - marks=pytest.mark.xfail( - reason="comparison of quantity with ndarrays in nanops not implemented" - ), - ), + method("sum"), method("std"), method("var"), method("cumsum"), @@ -1380,34 +2239,36 @@ def test_repr(self, func, variant, dtype): ids=repr, ) def test_aggregation(self, func, dtype): - array = np.arange(10).astype(dtype) * unit_registry.m - data_array = xr.DataArray(data=array) + array = np.arange(10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + data_array = xr.DataArray(data=array, dims="x") - expected = xr.DataArray(data=func(array)) - result = func(data_array) + # units differ based on the applied function, so we need to + # first compute the units + units = extract_units(func(array)) + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", ( pytest.param(operator.neg, id="negate"), pytest.param(abs, id="absolute"), - pytest.param( - np.round, - id="round", - marks=pytest.mark.xfail(reason="pint does not implement round"), - ), + pytest.param(np.round, id="round"), ), ) def test_unary_operations(self, func, dtype): array = np.arange(10).astype(dtype) * unit_registry.m data_array = xr.DataArray(data=array) - expected = xr.DataArray(data=func(array)) - result = func(data_array) + units = extract_units(func(array)) + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -1415,23 +2276,18 @@ def test_unary_operations(self, func, dtype): pytest.param(lambda x: 2 * x, id="multiply"), pytest.param(lambda x: x + x, id="add"), pytest.param(lambda x: x[0] + x, id="add scalar"), - pytest.param( - lambda x: x.T @ x, - id="matrix multiply", - marks=pytest.mark.xfail( - reason="pint does not 
support matrix multiplication yet" - ), - ), + pytest.param(lambda x: x.T @ x, id="matrix multiply"), ), ) def test_binary_operations(self, func, dtype): array = np.arange(10).astype(dtype) * unit_registry.m data_array = xr.DataArray(data=array) - expected = xr.DataArray(data=func(array)) - result = func(data_array) + units = extract_units(func(array)) + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "comparison", @@ -1448,8 +2304,9 @@ def test_binary_operations(self, func, dtype): pytest.param( unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), - pytest.param(unit_registry.s, DimensionalityError, id="incorrect_unit"), - pytest.param(unit_registry.m, None, id="correct_unit"), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) def test_comparison_operations(self, comparison, unit, error, dtype): @@ -1469,48 +2326,85 @@ def test_comparison_operations(self, comparison, unit, error, dtype): with pytest.raises(error): comparison(data_array, to_compare_with) - else: - result = comparison(data_array, to_compare_with) - # pint compares incompatible arrays to False, so we need to extend - # the multiplication works for both scalar and array results - expected = xr.DataArray( - data=comparison(array, to_compare_with) - * np.ones_like(array, dtype=bool) - ) - assert_equal_with_units(expected, result) + return + + actual = comparison(data_array, to_compare_with) + + expected_units = {None: unit_registry.m if array.check(unit) else None} + expected = array.check(unit) & comparison( + strip_units(data_array), + strip_units(convert_units(to_compare_with, expected_units)), + ) + + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "units,error", ( pytest.param(unit_registry.dimensionless, None, id="dimensionless"), - pytest.param(unit_registry.m, DimensionalityError, id="incorrect unit"), - pytest.param(unit_registry.degree, None, id="correct unit"), + pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.degree, None, id="compatible_unit"), ), ) def test_univariate_ufunc(self, units, error, dtype): array = np.arange(10).astype(dtype) * units data_array = xr.DataArray(data=array) + func = function("sin") + if error is not None: with pytest.raises(error): np.sin(data_array) - else: - expected = xr.DataArray(data=np.sin(array)) - result = np.sin(data_array) - assert_equal_with_units(expected, result) + return + + expected = attach_units( + func(strip_units(convert_units(data_array, {None: unit_registry.radians}))), + {None: unit_registry.dimensionless}, + ) + actual = func(data_array) + + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="pint's implementation of `np.maximum` strips units") - def test_bivariate_ufunc(self, dtype): - unit = unit_registry.m - array = np.arange(10).astype(dtype) * unit + @pytest.mark.xfail(reason="xarray's `np.maximum` strips units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="without_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, 
id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_bivariate_ufunc(self, unit, error, dtype): + original_unit = unit_registry.m + array = np.arange(10).astype(dtype) * original_unit data_array = xr.DataArray(data=array) - expected = xr.DataArray(np.maximum(array, 0 * unit)) + if error is not None: + with pytest.raises(error): + np.maximum(data_array, 0 * unit) + + return + + expected_units = {None: original_unit} + expected = attach_units( + np.maximum( + strip_units(data_array), + strip_units(convert_units(0 * unit, expected_units)), + ), + expected_units, + ) - assert_equal_with_units(expected, np.maximum(data_array, 0 * unit)) - assert_equal_with_units(expected, np.maximum(0 * unit, data_array)) + actual = np.maximum(data_array, 0 * unit) + assert_equal_with_units(expected, actual) + + actual = np.maximum(0 * unit, data_array) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize("property", ("T", "imag", "real")) def test_numpy_properties(self, property, dtype): @@ -1518,41 +2412,43 @@ def test_numpy_properties(self, property, dtype): np.arange(5 * 10).astype(dtype) + 1j * np.linspace(-1, 0, 5 * 10).astype(dtype) ).reshape(5, 10) * unit_registry.s + data_array = xr.DataArray(data=array, dims=("x", "y")) - expected = xr.DataArray( - data=getattr(array, property), - dims=("x", "y")[:: 1 if property != "T" else -1], + expected = attach_units( + getattr(strip_units(data_array), property), extract_units(data_array) ) - result = getattr(data_array, property) + actual = getattr(data_array, property) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", - ( - method("conj"), - method("argsort"), - method("conjugate"), - method("round"), - pytest.param( - method("rank", dim="x"), - marks=pytest.mark.xfail(reason="pint does not implement rank yet"), - ), - ), + (method("conj"), method("argsort"), method("conjugate"), method("round")), ids=repr, ) def test_numpy_methods(self, func, dtype): array = np.arange(10).astype(dtype) * unit_registry.m data_array = xr.DataArray(data=array, dims="x") - expected = xr.DataArray(func(array), dims="x") - result = func(data_array) + units = extract_units(func(array)) + expected = attach_units(strip_units(data_array), units) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( - "func", (method("clip", min=3, max=8), method("searchsorted", v=5)), ids=repr + "func", + ( + method("clip", min=3, max=8), + pytest.param( + method("searchsorted", v=5), + marks=pytest.mark.xfail( + reason="searchsorted somehow requires a undocumented `keys` argument" + ), + ), + ), + ids=repr, ) @pytest.mark.parametrize( "unit,error", @@ -1575,20 +2471,24 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype): key: (value * unit if isinstance(value, scalar_types) else value) for key, value in func.kwargs.items() } - if error is not None: with pytest.raises(error): func(data_array, **kwargs) - else: - expected = func(array, **kwargs) - if func.name not in ["searchsorted"]: - expected = xr.DataArray(data=expected) - result = func(data_array, **kwargs) - if func.name in ["searchsorted"]: - assert np.allclose(expected, result) - else: - assert_equal_with_units(expected, result) + return + + units = extract_units(data_array) + expected_units = extract_units(func(array, **kwargs)) + stripped_kwargs = { + key: strip_units(convert_units(value, units)) + for key, 
value in kwargs.items() + } + expected = attach_units( + func(strip_units(data_array), **stripped_kwargs), expected_units + ) + actual = func(data_array, **kwargs) + + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", (method("isnull"), method("notnull"), method("count")), ids=repr @@ -1611,9 +2511,9 @@ def test_missing_value_detection(self, func, dtype): data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y")) expected = func(strip_units(data_array)) - result = func(data_array) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.xfail(reason="ffill and bfill lose units in data") @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr) @@ -1623,48 +2523,67 @@ def test_missing_value_filling(self, func, dtype): * unit_registry.degK ) x = np.arange(len(array)) - data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) - - result_without_units = func(strip_units(data_array), dim="x") - result = xr.DataArray( - data=result_without_units.data * unit_registry.degK, - coords={"x": x}, - dims=["x"], - ) + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") expected = attach_units( - func(strip_units(data_array), dim="x"), {"data": unit_registry.degK} + func(strip_units(data_array), dim="x"), extract_units(data_array) ) - result = func(data_array, dim="x") + actual = func(data_array, dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="fillna drops the unit") @pytest.mark.parametrize( - "fill_value", + "unit,error", ( + pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( - -1, - id="python scalar", - marks=pytest.mark.xfail( - reason="python scalar cannot be converted using astype()" - ), + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="fillna converts to value's unit"), ), - pytest.param(np.array(-1), id="numpy scalar"), - pytest.param(np.array([-1]), id="numpy array"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "fill_value", + ( + pytest.param(-1, id="python_scalar"), + pytest.param(np.array(-1), id="numpy_scalar"), + pytest.param(np.array([-1]), id="numpy_array"), ), ) - def test_fillna(self, fill_value, dtype): - unit = unit_registry.m - array = np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * unit + def test_fillna(self, fill_value, unit, error, dtype): + original_unit = unit_registry.m + array = ( + np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) + * original_unit + ) data_array = xr.DataArray(data=array) + func = method("fillna") + + value = fill_value * unit + if error is not None: + with pytest.raises(error): + func(data_array, value=value) + + return + + units = extract_units(data_array) expected = attach_units( - strip_units(data_array).fillna(value=fill_value), {"data": unit} + func( + strip_units(data_array), value=strip_units(convert_units(value, units)) + ), + units, ) - result = data_array.fillna(value=fill_value * unit) + actual = func(data_array, value=value) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) def test_dropna(self, dtype): array = ( @@ -1674,22 +2593,26 @@ def test_dropna(self, dtype): 
x = np.arange(len(array)) data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) - expected = attach_units( - strip_units(data_array).dropna(dim="x"), {"data": unit_registry.m} - ) - result = data_array.dropna(dim="x") + units = extract_units(data_array) + expected = attach_units(strip_units(data_array).dropna(dim="x"), units) + actual = data_array.dropna(dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="pint does not implement `numpy.isin`") @pytest.mark.parametrize( "unit", ( - pytest.param(1, id="no_unit"), + pytest.param( + 1, + id="no_unit", + marks=pytest.mark.xfail( + reason="pint's isin implementation does not work well with mixed args" + ), + ), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), pytest.param(unit_registry.cm, id="compatible_unit"), - pytest.param(unit_registry.m, id="same_unit"), + pytest.param(unit_registry.m, id="identical_unit"), ), ) def test_isin(self, unit, dtype): @@ -1702,33 +2625,26 @@ def test_isin(self, unit, dtype): raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype) values = raw_values * unit - result_without_units = strip_units(data_array).isin(raw_values) - if unit != unit_registry.m: - result_without_units[:] = False - result_with_units = data_array.isin(values) + units = {None: unit_registry.m if array.check(unit) else None} + expected = strip_units(data_array).isin( + strip_units(convert_units(values, units)) + ) & array.check(unit) + actual = data_array.isin(values) - assert_equal_with_units(result_without_units, result_with_units) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "variant", ( pytest.param( "masking", - marks=pytest.mark.xfail(reason="nan not compatible with quantity"), - ), - pytest.param( - "replacing_scalar", - marks=pytest.mark.xfail(reason="scalar not convertible using astype"), - ), - pytest.param( - "replacing_array", - marks=pytest.mark.xfail( - reason="replacing using an array drops the units" - ), + marks=pytest.mark.xfail(reason="array(nan) is not a quantity"), ), + "replacing_scalar", + "replacing_array", pytest.param( "dropping", - marks=pytest.mark.xfail(reason="nan not compatible with quantity"), + marks=pytest.mark.xfail(reason="array(nan) is not a quantity"), ), ), ) @@ -1741,13 +2657,10 @@ def test_isin(self, unit, dtype): ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="same_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) def test_where(self, variant, unit, error, dtype): - def _strip_units(mapping): - return {key: array_strip_units(value) for key, value in mapping.items()} - original_unit = unit_registry.m array = np.linspace(0, 1, 10).astype(dtype) * original_unit @@ -1761,20 +2674,29 @@ def _strip_units(mapping): "replacing_array": {"cond": condition, "other": other}, "dropping": {"cond": condition, "drop": True}, } - kwargs = variant_kwargs.get(variant) - kwargs_without_units = _strip_units(kwargs) + kwargs = variant_kwargs.get(variant) + kwargs_without_units = { + key: strip_units( + convert_units( + value, {None: original_unit if array.check(unit) else None} + ) + ) + for key, value in kwargs.items() + } if variant not in ("masking", "dropping") and error is not None: with pytest.raises(error): data_array.where(**kwargs) - else: - expected = attach_units( - 
strip_units(array).where(**kwargs_without_units), - {"data": original_unit}, - ) - result = data_array.where(**kwargs) - assert_equal_with_units(expected, result) + return + + expected = attach_units( + strip_units(data_array).where(**kwargs_without_units), + extract_units(data_array), + ) + actual = data_array.where(**kwargs) + + assert_equal_with_units(expected, actual) @pytest.mark.xfail(reason="interpolate strips units") def test_interpolate_na(self, dtype): @@ -1785,14 +2707,12 @@ def test_interpolate_na(self, dtype): x = np.arange(len(array)) data_array = xr.DataArray(data=array, coords={"x": x}, dims="x").astype(dtype) - expected = attach_units( - strip_units(data_array).interpolate_na(dim="x"), {"data": unit_registry.m} - ) - result = data_array.interpolate_na(dim="x") + units = extract_units(data_array) + expected = attach_units(strip_units(data_array).interpolate_na(dim="x"), units) + actual = data_array.interpolate_na(dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="uses DataArray.where, which currently fails") @pytest.mark.parametrize( "unit,error", ( @@ -1801,8 +2721,18 @@ def test_interpolate_na(self, dtype): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="depends on reindex"), + ), + pytest.param( + unit_registry.m, + None, + id="identical_unit", + marks=pytest.mark.xfail(reason="depends on reindex"), + ), ), ) def test_combine_first(self, unit, error, dtype): @@ -1819,14 +2749,19 @@ def test_combine_first(self, unit, error, dtype): if error is not None: with pytest.raises(error): data_array.combine_first(other) - else: - expected = attach_units( - strip_units(data_array).combine_first(strip_units(other)), - {"data": unit_registry.m}, - ) - result = data_array.combine_first(other) - assert_equal_with_units(expected, result) + return + + units = extract_units(data_array) + expected = attach_units( + strip_units(data_array).combine_first( + strip_units(convert_units(other, units)) + ), + units, + ) + actual = data_array.combine_first(other) + + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "unit", @@ -1834,11 +2769,7 @@ def test_combine_first(self, unit, error, dtype): pytest.param(1, id="no_unit"), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), - pytest.param( - unit_registry.cm, - id="compatible_unit", - marks=pytest.mark.xfail(reason="identical does not check units yet"), - ), + pytest.param(unit_registry.cm, id="compatible_unit"), pytest.param(unit_registry.m, id="identical_unit"), ), ) @@ -1854,53 +2785,51 @@ def test_combine_first(self, unit, error, dtype): ) @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) def test_comparisons(self, func, variation, unit, dtype): + def is_compatible(a, b): + a = a if a is not None else 1 + b = b if b is not None else 1 + quantity = np.arange(5) * a + + return a == b or quantity.check(b) + data = np.linspace(0, 5, 10).astype(dtype) coord = np.arange(len(data)).astype(dtype) base_unit = unit_registry.m - quantity = data * base_unit - x = coord * base_unit - y = coord * base_unit - - units = { - "data": (unit, base_unit, 
base_unit), - "dims": (base_unit, unit, base_unit), - "coords": (base_unit, base_unit, unit), + array = data * (base_unit if variation == "data" else 1) + x = coord * (base_unit if variation == "dims" else 1) + y = coord * (base_unit if variation == "coords" else 1) + + variations = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), } - data_unit, dim_unit, coord_unit = units.get(variation) + data_unit, dim_unit, coord_unit = variations.get(variation) - data_array = xr.DataArray( - data=quantity, coords={"x": x, "y": ("x", y)}, dims="x" - ) + data_array = xr.DataArray(data=array, coords={"x": x, "y": ("x", y)}, dims="x") other = attach_units( - strip_units(data_array), - { - None: (data_unit, base_unit if quantity.check(data_unit) else None), - "x": (dim_unit, base_unit if x.check(dim_unit) else None), - "y": (coord_unit, base_unit if y.check(coord_unit) else None), - }, + strip_units(data_array), {None: data_unit, "x": dim_unit, "y": coord_unit} ) - # TODO: test dim coord once indexes leave units intact - # also, express this in terms of calls on the raw data array - # and then check the units - equal_arrays = ( - np.all(quantity == other.data) - and (np.all(x == other.x.data) or True) # dims can't be checked yet - and np.all(y == other.y.data) - ) - equal_units = ( - data_unit == unit_registry.m - and coord_unit == unit_registry.m - and dim_unit == unit_registry.m + units = extract_units(data_array) + other_units = extract_units(other) + + equal_arrays = all( + is_compatible(units[name], other_units[name]) for name in units.keys() + ) and ( + strip_units(data_array).equals( + strip_units(convert_units(other, extract_units(data_array))) + ) ) + equal_units = units == other_units expected = equal_arrays and (func.name != "identical" or equal_units) - result = func(data_array, other) - assert expected == result + actual = func(data_array, other) + + assert expected == actual - @pytest.mark.xfail(reason="blocked by `where`") @pytest.mark.parametrize( "unit", ( @@ -1926,9 +2855,9 @@ def test_broadcast_like(self, unit, dtype): expected = attach_units( strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1) ) - result = arr1.broadcast_like(arr2) + actual = arr1.broadcast_like(arr2) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "unit", @@ -1942,19 +2871,21 @@ def test_broadcast_like(self, unit, dtype): ) def test_broadcast_equals(self, unit, dtype): left_array = np.ones(shape=(2, 2), dtype=dtype) * unit_registry.m - right_array = array_attach_units( - np.ones(shape=(2,), dtype=dtype), - unit, - convert_from=unit_registry.m if left_array.check(unit) else None, - ) + right_array = np.ones(shape=(2,), dtype=dtype) * unit left = xr.DataArray(data=left_array, dims=("x", "y")) right = xr.DataArray(data=right_array, dims="x") - expected = np.all(left_array == right_array[:, None]) - result = left.broadcast_equals(right) + units = { + **extract_units(left), + **({} if left_array.check(unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & left_array.check(unit) + actual = left.broadcast_equals(right) - assert expected == result + assert expected == actual @pytest.mark.parametrize( "func", @@ -1969,16 +2900,11 @@ def test_broadcast_equals(self, unit, dtype): dim={"z": np.linspace(10, 20, 12) * unit_registry.s}, axis=1, ), - method("drop_sel", labels="x"), + method("drop_vars", "x"), method("reset_coords", names="x2"), 
method("copy"), - pytest.param( - method("astype", np.float32), - marks=pytest.mark.xfail(reason="units get stripped"), - ), - pytest.param( - method("item", 1), marks=pytest.mark.xfail(reason="units get stripped") - ), + method("astype", np.float32), + method("item", 1), ), ids=repr, ) @@ -2001,67 +2927,38 @@ def test_content_manipulation(self, func, dtype): stripped_kwargs = { key: array_strip_units(value) for key, value in func.kwargs.items() } - expected = attach_units( - func(strip_units(data_array), **stripped_kwargs), - { - "data": quantity.units, - "x": x.units, - "x_mm": x2.units, - "x2": x2.units, - "y": y.units, - }, - ) - result = func(data_array) + units = {**{"x_mm": x2.units, "x2": x2.units}, **extract_units(data_array)} + + expected = attach_units(func(strip_units(data_array), **stripped_kwargs), units) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( - "func", - ( - pytest.param( - method("drop_sel", labels=dict(x=np.array([1, 5]))), - marks=pytest.mark.xfail( - reason="selecting using incompatible units does not raise" - ), - ), - pytest.param(method("copy", data=np.arange(20))), - ), - ids=repr, + "func", (pytest.param(method("copy", data=np.arange(20))),), ids=repr ) @pytest.mark.parametrize( - "unit,error", + "unit", ( - pytest.param(1, DimensionalityError, id="no_unit"), - pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" - ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, KeyError, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.degK, id="with_unit"), ), ) - def test_content_manipulation_with_units(self, func, unit, error, dtype): + def test_content_manipulation_with_units(self, func, unit, dtype): quantity = np.linspace(0, 10, 20, dtype=dtype) * unit_registry.pascal x = np.arange(len(quantity)) * unit_registry.m - data_array = xr.DataArray(name="data", data=quantity, coords={"x": x}, dims="x") + data_array = xr.DataArray(data=quantity, coords={"x": x}, dims="x") - kwargs = { - key: (value * unit if isinstance(value, np.ndarray) else value) - for key, value in func.kwargs.items() - } - stripped_kwargs = func.kwargs + kwargs = {key: value * unit for key, value in func.kwargs.items()} expected = attach_units( - func(strip_units(data_array), **stripped_kwargs), - {"data": quantity.units if func.name == "drop_sel" else unit, "x": x.units}, + func(strip_units(data_array)), {None: unit, "x": x.units} ) - if error is not None and func.name == "drop_sel": - with pytest.raises(error): - func(data_array, **kwargs) - else: - result = func(data_array, **kwargs) - assert_equal_with_units(expected, result) + + actual = func(data_array, **kwargs) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "indices", @@ -2074,95 +2971,152 @@ def test_isel(self, indices, dtype): array = np.arange(10).astype(dtype) * unit_registry.s x = np.arange(len(array)) * unit_registry.m - data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") expected = attach_units( - strip_units(data_array).isel(x=indices), - {"data": unit_registry.s, "x": unit_registry.m}, + strip_units(data_array).isel(x=indices), extract_units(data_array) ) - result = 
data_array.isel(x=indices) + actual = data_array.isel(x=indices) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="xarray does not support duck arrays in dimension coordinates" - ) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( - "values", + "raw_values", ( - pytest.param(12, id="single value"), - pytest.param([10, 5, 13], id="list of multiple values"), - pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"), + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), ), ) @pytest.mark.parametrize( - "units,error", + "unit,error", ( - pytest.param(1, KeyError, id="no units"), + pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), - pytest.param(unit_registry.degree, KeyError, id="incorrect unit"), - pytest.param(unit_registry.s, None, id="correct unit"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_sel(self, values, units, error, dtype): + def test_sel(self, raw_values, unit, error, dtype): array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m - x = np.arange(len(array)) * unit_registry.s - data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") - values_with_units = values * units + values = raw_values * unit - if error is not None: + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): with pytest.raises(error): - data_array.sel(x=values_with_units) - else: - result_array = array[values] - result_data_array = data_array.sel(x=values_with_units) - assert_equal_with_units(result_array, result_data_array) + data_array.sel(x=values) + + return + + expected = attach_units( + strip_units(data_array).sel( + x=strip_units(convert_units(values, {None: array.units})) + ), + extract_units(data_array), + ) + actual = data_array.sel(x=values) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="xarray does not support duck arrays in dimension coordinates" + @pytest.mark.xfail(reason="indexes don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), ) + def test_loc(self, raw_values, unit, error, dtype): + array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") + + values = raw_values * unit + + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): + with pytest.raises(error): + data_array.loc[{"x": values}] + + return + + expected = attach_units( + strip_units(data_array).loc[ + 
{"x": strip_units(convert_units(values, {None: array.units}))} + ], + extract_units(data_array), + ) + actual = data_array.loc[{"x": values}] + assert_equal_with_units(expected, actual) + + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( - "values", + "raw_values", ( - pytest.param(12, id="single value"), - pytest.param([10, 5, 13], id="list of multiple values"), - pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"), + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), ), ) @pytest.mark.parametrize( - "units,error", + "unit,error", ( - pytest.param(1, KeyError, id="no units"), + pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), - pytest.param(unit_registry.degree, KeyError, id="incorrect unit"), - pytest.param(unit_registry.s, None, id="correct unit"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_loc(self, values, units, error, dtype): + def test_drop_sel(self, raw_values, unit, error, dtype): array = np.linspace(5, 10, 20).astype(dtype) * unit_registry.m - x = np.arange(len(array)) * unit_registry.s - data_array = xr.DataArray(data=array, coords={"x": x}, dims=["x"]) + x = np.arange(len(array)) * unit_registry.m + data_array = xr.DataArray(data=array, coords={"x": x}, dims="x") - values_with_units = values * units + values = raw_values * unit - if error is not None: + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): with pytest.raises(error): - data_array.loc[values_with_units] - else: - result_array = array[values] - result_data_array = data_array.loc[values_with_units] - assert_equal_with_units(result_array, result_data_array) + data_array.drop_sel(x=values) + + return + + expected = attach_units( + strip_units(data_array).drop_sel( + x=strip_units(convert_units(values, {None: x.units})) + ), + extract_units(data_array), + ) + actual = data_array.drop_sel(x=values) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="tries to coerce using asarray") @pytest.mark.parametrize( "shape", ( - pytest.param((10, 20), id="nothing squeezable"), - pytest.param((10, 20, 1), id="last dimension squeezable"), - pytest.param((10, 1, 20), id="middle dimension squeezable"), - pytest.param((1, 10, 20), id="first dimension squeezable"), - pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"), + pytest.param((10, 20), id="nothing_squeezable"), + pytest.param((10, 20, 1), id="last_dimension_squeezable"), + pytest.param((10, 1, 20), id="middle_dimension_squeezable"), + pytest.param((1, 10, 20), id="first_dimension_squeezable"), + pytest.param((1, 10, 1, 20), id="first_and_last_dimension_squeezable"), ), ) def test_squeeze(self, shape, dtype): @@ -2177,38 +3131,27 @@ def test_squeeze(self, shape, dtype): data=array, coords=coords, dims=tuple(names[: len(shape)]) ) - result_array = array.squeeze() - result_data_array = data_array.squeeze() - assert_equal_with_units(result_array, result_data_array) + expected = attach_units( + strip_units(data_array).squeeze(), extract_units(data_array) + ) + actual = data_array.squeeze() + assert_equal_with_units(expected, actual) # try squeezing the dimensions separately names = tuple(dim for dim, coord in 
coords.items() if len(coord) == 1) for index, name in enumerate(names): - assert_equal_with_units( - np.squeeze(array, axis=index), data_array.squeeze(dim=name) + expected = attach_units( + strip_units(data_array).squeeze(dim=name), extract_units(data_array) ) + actual = data_array.squeeze(dim=name) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="indexes strip units and head / tail / thin only support integers" - ) - @pytest.mark.parametrize( - "unit,error", - ( - pytest.param(1, DimensionalityError, id="no_unit"), - pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" - ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), - ), - ) @pytest.mark.parametrize( "func", (method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)), ids=repr, ) - def test_head_tail_thin(self, func, unit, error, dtype): + def test_head_tail_thin(self, func, dtype): array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK coords = { @@ -2216,27 +3159,24 @@ def test_head_tail_thin(self, func, unit, error, dtype): "y": np.arange(5) * unit_registry.m, } - arr = xr.DataArray(data=array, coords=coords, dims=("x", "y")) - - kwargs = {name: value * unit for name, value in func.kwargs.items()} - - if error is not None: - with pytest.raises(error): - func(arr, **kwargs) - - return + data_array = xr.DataArray(data=array, coords=coords, dims=("x", "y")) - expected = attach_units(func(strip_units(arr)), extract_units(arr)) - result = func(arr, **kwargs) + expected = attach_units( + func(strip_units(data_array)), extract_units(data_array) + ) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, None, id="no_unit"), - pytest.param(unit_registry.dimensionless, None, id="dimensionless"), - pytest.param(unit_registry.s, None, id="incompatible_unit"), + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), pytest.param(unit_registry.cm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), @@ -2254,24 +3194,29 @@ def test_interp(self, unit, error): if error is not None: with pytest.raises(error): data_array.interp(x=new_coords) - else: - new_coords_ = ( - new_coords.magnitude if hasattr(new_coords, "magnitude") else new_coords - ) - result_array = strip_units(data_array).interp( - x=new_coords_ * unit_registry.degK - ) - result_data_array = data_array.interp(x=new_coords) - assert_equal_with_units(result_array, result_data_array) + return + + units = extract_units(data_array) + expected = attach_units( + strip_units(data_array).interp( + x=strip_units(convert_units(new_coords, {None: unit_registry.m})) + ), + units, + ) + actual = data_array.interp(x=new_coords) + + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="tries to coerce using asarray") + @pytest.mark.xfail(reason="indexes strip units") @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, None, id="no_unit"), - pytest.param(unit_registry.dimensionless, None, id="dimensionless"), - pytest.param(unit_registry.s, None, 
id="incompatible_unit"), + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), pytest.param(unit_registry.cm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), @@ -2284,43 +3229,46 @@ def test_interp_like(self, unit, error): } data_array = xr.DataArray(array, coords=coords, dims=("x", "y")) - new_data_array = xr.DataArray( - data=np.empty((20, 10)), + other = xr.DataArray( + data=np.empty((20, 10)) * unit_registry.degK, coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit}, dims=("x", "y"), ) if error is not None: with pytest.raises(error): - data_array.interp_like(new_data_array) - else: - result_array = ( - xr.DataArray( - data=array.magnitude, - coords={name: value.magnitude for name, value in coords.items()}, - dims=("x", "y"), - ).interp_like(strip_units(new_data_array)) - * unit_registry.degK - ) - result_data_array = data_array.interp_like(new_data_array) + data_array.interp_like(other) - assert_equal_with_units(result_array, result_data_array) + return - @pytest.mark.xfail( - reason="pint does not implement np.result_type in __array_function__ yet" - ) + units = extract_units(data_array) + expected = attach_units( + strip_units(data_array).interp_like( + strip_units(convert_units(other, units)) + ), + units, + ) + actual = data_array.interp_like(other) + + assert_equal_with_units(expected, actual) + + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, None, id="no_unit"), - pytest.param(unit_registry.dimensionless, None, id="dimensionless"), - pytest.param(unit_registry.s, None, id="incompatible_unit"), + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), pytest.param(unit_registry.cm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_reindex(self, unit, error): - array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + def test_reindex(self, unit, error, dtype): + array = ( + np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) new_coords = (np.arange(10) + 0.5) * unit coords = { "x": np.arange(10) * unit_registry.m, @@ -2328,65 +3276,70 @@ def test_reindex(self, unit, error): } data_array = xr.DataArray(array, coords=coords, dims=("x", "y")) + func = method("reindex") if error is not None: with pytest.raises(error): - data_array.interp(x=new_coords) - else: - result_array = strip_units(data_array).reindex( - x=( - new_coords.magnitude - if hasattr(new_coords, "magnitude") - else new_coords - ) - * unit_registry.degK - ) - result_data_array = data_array.reindex(x=new_coords) + func(data_array, x=new_coords) - assert_equal_with_units(result_array, result_data_array) + return - @pytest.mark.xfail( - reason="pint does not implement np.result_type in __array_function__ yet" - ) + expected = attach_units( + func( + strip_units(data_array), + x=strip_units(convert_units(new_coords, {None: unit_registry.m})), + ), + {None: unit_registry.degK}, + ) + actual = func(data_array, x=new_coords) + + assert_equal_with_units(expected, actual) + + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( - 
pytest.param(1, None, id="no_unit"), - pytest.param(unit_registry.dimensionless, None, id="dimensionless"), - pytest.param(unit_registry.s, None, id="incompatible_unit"), + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), pytest.param(unit_registry.cm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_reindex_like(self, unit, error): - array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + def test_reindex_like(self, unit, error, dtype): + array = ( + np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) coords = { "x": (np.arange(10) + 0.3) * unit_registry.m, "y": (np.arange(5) + 0.3) * unit_registry.m, } data_array = xr.DataArray(array, coords=coords, dims=("x", "y")) - new_data_array = xr.DataArray( - data=np.empty((20, 10)), + other = xr.DataArray( + data=np.empty((20, 10)) * unit_registry.degK, coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit}, dims=("x", "y"), ) if error is not None: with pytest.raises(error): - data_array.reindex_like(new_data_array) - else: - expected = attach_units( - strip_units(data_array).reindex_like(strip_units(new_data_array)), - { - "data": unit_registry.degK, - "x": unit_registry.m, - "y": unit_registry.m, - }, - ) - result = data_array.reindex_like(new_data_array) + data_array.reindex_like(other) + + return + + units = extract_units(data_array) + expected = attach_units( + strip_units(data_array).reindex_like( + strip_units(convert_units(other, units)) + ), + units, + ) + actual = data_array.reindex_like(other) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -2406,11 +3359,11 @@ def test_stacking_stacked(self, func, dtype): stacked = data_array.stack(z=("x", "y")) expected = attach_units(func(strip_units(stacked)), {"data": unit_registry.m}) - result = func(stacked) + actual = func(stacked) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="indexes strip the label units") + @pytest.mark.xfail(reason="indexes don't support units") def test_to_unstacked_dataset(self, dtype): array = ( np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) @@ -2429,13 +3382,9 @@ def test_to_unstacked_dataset(self, dtype): func(strip_units(data_array)), {"y": y.units, **dict(zip(x.magnitude, [array.units] * len(y)))}, ).rename({elem.magnitude: elem for elem in x}) - result = func(data_array) + actual = func(data_array) - print(data_array, expected, result, sep="\n") - - assert_equal_with_units(expected, result) - - assert False + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -2446,10 +3395,7 @@ def test_to_unstacked_dataset(self, dtype): pytest.param( method("shift", x=2), marks=pytest.mark.xfail(reason="strips units") ), - pytest.param( - method("roll", x=2, roll_coords=False), - marks=pytest.mark.xfail(reason="strips units"), - ), + method("roll", x=2, roll_coords=False), method("sortby", "x2"), ), ids=repr, @@ -2471,12 +3417,10 @@ def test_stacking_reordering(self, func, dtype): dims=("x", "y", "z"), ) - expected = attach_units( - func(strip_units(data_array)), {"data": unit_registry.m} - ) - result = func(data_array) + expected = attach_units(func(strip_units(data_array)), {None: unit_registry.m}) + actual = 
func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -2486,18 +3430,15 @@ def test_stacking_reordering(self, func, dtype): method("integrate", dim="x"), pytest.param( method("quantile", q=[0.25, 0.75]), - marks=pytest.mark.xfail( - reason="pint does not implement nanpercentile yet" - ), - ), - pytest.param( - method("reduce", func=np.sum, dim="x"), - marks=pytest.mark.xfail(reason="strips units"), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), ), + method("reduce", func=np.sum, dim="x"), pytest.param( lambda x: x.dot(x), id="method_dot", - marks=pytest.mark.xfail(reason="pint does not implement einsum"), + marks=pytest.mark.xfail( + reason="pint does not implement the dot method" + ), ), ), ids=repr, @@ -2511,30 +3452,35 @@ def test_computation(self, func, dtype): y = np.arange(array.shape[1]) * unit_registry.s data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y")) - units = extract_units(data_array) + + # we want to make sure the output unit is correct + units = { + **extract_units(data_array), + **( + {} + if isinstance(func, (function, method)) + else extract_units(func(array.reshape(-1))) + ), + } expected = attach_units(func(strip_units(data_array)), units) - result = func(data_array) + actual = func(data_array) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", ( - pytest.param( - method("groupby", "y"), marks=pytest.mark.xfail(reason="strips units") - ), - pytest.param( - method("groupby_bins", "y", bins=4), - marks=pytest.mark.xfail(reason="strips units"), - ), + method("groupby", "x"), + method("groupby_bins", "y", bins=4), method("coarsen", y=2), pytest.param( - method("rolling", y=3), marks=pytest.mark.xfail(reason="strips units") + method("rolling", y=3), + marks=pytest.mark.xfail(reason="rolling strips units"), ), pytest.param( method("rolling_exp", y=3), - marks=pytest.mark.xfail(reason="strips units"), + marks=pytest.mark.xfail(reason="units not supported by numbagg"), ), ), ids=repr, @@ -2544,18 +3490,17 @@ def test_computation_objects(self, func, dtype): np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m ) - x = np.arange(array.shape[0]) * unit_registry.m + x = np.array([0, 0, 1, 2, 2]) * unit_registry.m y = np.arange(array.shape[1]) * 3 * unit_registry.s data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y")) units = extract_units(data_array) expected = attach_units(func(strip_units(data_array)).mean(), units) - result = func(data_array).mean() + actual = func(data_array).mean() - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="strips units") def test_resample(self, dtype): array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m @@ -2566,22 +3511,19 @@ def test_resample(self, dtype): func = method("resample", time="6m") expected = attach_units(func(strip_units(data_array)).mean(), units) - result = func(data_array).mean() + actual = func(data_array).mean() - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", ( - pytest.param( - method("assign_coords", {"z": (["x"], np.arange(5) * unit_registry.s)}), - marks=pytest.mark.xfail(reason="strips units"), - ), - pytest.param(method("first")), - pytest.param(method("last")), + method("assign_coords", z=(["x"], np.arange(5) * unit_registry.s)), + 
method("first"), + method("last"), pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), - marks=pytest.mark.xfail(reason="strips units"), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), ), ), ids=repr, @@ -2595,12 +3537,22 @@ def test_grouped_operations(self, func, dtype): y = np.arange(array.shape[1]) * 3 * unit_registry.s data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y")) - units = extract_units(data_array) + units = {**extract_units(data_array), **{"z": unit_registry.s, "q": None}} - expected = attach_units(func(strip_units(data_array).groupby("y")), units) - result = func(data_array.groupby("y")) + stripped_kwargs = { + key: ( + strip_units(value) + if not isinstance(value, tuple) + else tuple(strip_units(elem) for elem in value) + ) + for key, value in func.kwargs.items() + } + expected = attach_units( + func(strip_units(data_array).groupby("y"), **stripped_kwargs), units + ) + actual = func(data_array.groupby("y")) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) class TestDataset: @@ -2620,10 +3572,7 @@ class TestDataset: "shared", ( "nothing", - pytest.param( - "dims", - marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), - ), + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), pytest.param( "coords", marks=pytest.mark.xfail(reason="reindex does not work with pint yet"), @@ -2674,7 +3623,7 @@ def test_init(self, shared, unit, error, dtype): return - result = xr.Dataset(data_vars={"a": arr1, "b": arr2}) + actual = xr.Dataset(data_vars={"a": arr1, "b": arr2}) expected_units = { "a": a.units, @@ -2688,7 +3637,7 @@ def test_init(self, shared, unit, error, dtype): xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}), expected_units, ) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize( "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr")) @@ -2749,12 +3698,7 @@ def test_repr(self, func, variant, dtype): reason="np.median does not work with dataset yet" ), ), - pytest.param( - function("sum"), - marks=pytest.mark.xfail( - reason="np.result_type not implemented by pint" - ), - ), + function("sum"), pytest.param( function("prod"), marks=pytest.mark.xfail(reason="not implemented by pint"), @@ -2764,9 +3708,7 @@ def test_repr(self, func, variant, dtype): function("cumsum"), pytest.param( function("cumprod"), - marks=pytest.mark.xfail( - reason="pint does not support cumprod on non-dimensionless yet" - ), + marks=pytest.mark.xfail(reason="fails within xarray"), ), pytest.param( method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") @@ -2780,12 +3722,7 @@ def test_repr(self, func, variant, dtype): method("min"), method("mean"), method("median"), - pytest.param( - method("sum"), - marks=pytest.mark.xfail( - reason="np.result_type not implemented by pint" - ), - ), + method("sum"), pytest.param( method("prod"), marks=pytest.mark.xfail(reason="not implemented by pint"), @@ -2794,17 +3731,20 @@ def test_repr(self, func, variant, dtype): method("var"), method("cumsum"), pytest.param( - method("cumprod"), - marks=pytest.mark.xfail( - reason="pint does not support cumprod on non-dimensionless yet" - ), + method("cumprod"), marks=pytest.mark.xfail(reason="fails within xarray") ), ), ids=repr, ) def test_aggregation(self, func, dtype): - unit_a = unit_registry.Pa - unit_b = unit_registry.kg / unit_registry.m ** 3 + unit_a = ( + unit_registry.Pa if 
func.name != "cumprod" else unit_registry.dimensionless + ) + unit_b = ( + unit_registry.kg / unit_registry.m ** 3 + if func.name != "cumprod" + else unit_registry.dimensionless + ) a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x") b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x") x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x") @@ -2814,13 +3754,16 @@ def test_aggregation(self, func, dtype): ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y}) - result = func(ds) + actual = func(ds) expected = attach_units( func(strip_units(ds)), - {"a": array_extract_units(func(a)), "b": array_extract_units(func(b))}, + { + "a": extract_units(func(a)).get(None), + "b": extract_units(func(b)).get(None), + }, ) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize("property", ("imag", "real")) def test_numpy_properties(self, property, dtype): @@ -2840,10 +3783,10 @@ def test_numpy_properties(self, property, dtype): ) units = extract_units(ds) - result = getattr(ds, property) + actual = getattr(ds, property) expected = attach_units(getattr(strip_units(ds), property), units) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize( "func", @@ -2853,10 +3796,6 @@ def test_numpy_properties(self, property, dtype): method("argsort"), method("conjugate"), method("round"), - pytest.param( - method("rank", dim="x"), - marks=pytest.mark.xfail(reason="pint does not implement rank yet"), - ), ), ids=repr, ) @@ -2882,10 +3821,10 @@ def test_numpy_methods(self, func, dtype): "y": unit_registry.s, } - result = func(ds) + actual = func(ds) expected = attach_units(func(strip_units(ds)), units) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr) @pytest.mark.parametrize( @@ -2914,37 +3853,26 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype): ) units = extract_units(ds) - def strip(value): - return ( - value.magnitude if isinstance(value, unit_registry.Quantity) else value - ) - - def convert(value, to): - if isinstance(value, unit_registry.Quantity) and value.check(to): - return value.to(to) - - return value - - scalar_types = (int, float) kwargs = { - key: (value * unit if isinstance(value, scalar_types) else value) + key: (value * unit if isinstance(value, (int, float)) else value) for key, value in func.kwargs.items() } - stripped_kwargs = { - key: strip(convert(value, data_unit)) for key, value in kwargs.items() - } - if error is not None: with pytest.raises(error): func(ds, **kwargs) return - result = func(ds, **kwargs) + stripped_kwargs = { + key: strip_units(convert_units(value, {None: data_unit})) + for key, value in kwargs.items() + } + + actual = func(ds, **kwargs) expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize( "func", (method("isnull"), method("notnull"), method("count")), ids=repr @@ -2987,9 +3915,9 @@ def test_missing_value_detection(self, func, dtype): ) expected = func(strip_units(ds)) - result = func(ds) + actual = func(ds) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.xfail(reason="ffill and bfill lose the unit") @pytest.mark.parametrize("func", (method("ffill"), 
method("bfill")), ids=repr) @@ -3017,40 +3945,35 @@ def test_missing_value_filling(self, func, dtype): func(strip_units(ds), dim="x"), {"a": unit_registry.degK, "b": unit_registry.Pa}, ) - result = func(ds, dim="x") + actual = func(ds, dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="fillna drops the unit") @pytest.mark.parametrize( "unit,error", ( - pytest.param( - 1, - DimensionalityError, - id="no_unit", - marks=pytest.mark.xfail(reason="blocked by the failing `where`"), - ), + pytest.param(1, DimensionalityError, id="no_unit"), pytest.param( unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="where converts the array, not the fill value" + ), + ), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @pytest.mark.parametrize( "fill_value", ( - pytest.param( - -1, - id="python scalar", - marks=pytest.mark.xfail( - reason="python scalar cannot be converted using astype()" - ), - ), - pytest.param(np.array(-1), id="numpy scalar"), - pytest.param(np.array([-1]), id="numpy array"), + pytest.param(-1, id="python_scalar"), + pytest.param(np.array(-1), id="numpy_scalar"), + pytest.param(np.array([-1]), id="numpy_array"), ), ) def test_fillna(self, fill_value, unit, error, dtype): @@ -3075,13 +3998,17 @@ def test_fillna(self, fill_value, unit, error, dtype): return - result = ds.fillna(value=fill_value * unit) + actual = ds.fillna(value=fill_value * unit) expected = attach_units( - strip_units(ds).fillna(value=fill_value), + strip_units(ds).fillna( + value=strip_units( + convert_units(fill_value * unit, {None: unit_registry.m}) + ) + ), {"a": unit_registry.m, "b": unit_registry.m}, ) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) def test_dropna(self, dtype): array1 = ( @@ -3105,11 +4032,10 @@ def test_dropna(self, dtype): strip_units(ds).dropna(dim="x"), {"a": unit_registry.degK, "b": unit_registry.Pa}, ) - result = ds.dropna(dim="x") + actual = ds.dropna(dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="pint does not implement `numpy.isin`") @pytest.mark.parametrize( "unit", ( @@ -3154,36 +4080,12 @@ def test_isin(self, unit, dtype): ): expected.a[:] = False expected.b[:] = False - result = ds.isin(values) + actual = ds.isin(values) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.parametrize( - "variant", - ( - pytest.param( - "masking", - marks=pytest.mark.xfail( - reason="np.result_type not implemented by quantity" - ), - ), - pytest.param( - "replacing_scalar", - marks=pytest.mark.xfail( - reason="python scalar not convertible using astype" - ), - ), - pytest.param( - "replacing_array", - marks=pytest.mark.xfail( - reason="replacing using an array drops the units" - ), - ), - pytest.param( - "dropping", - marks=pytest.mark.xfail(reason="nan not compatible with quantity"), - ), - ), + "variant", ("masking", "replacing_scalar", "replacing_array", "dropping") ) @pytest.mark.parametrize( "unit,error", @@ -3198,9 +4100,6 @@ def test_isin(self, unit, dtype): ), ) def test_where(self, variant, unit, error, dtype): - def _strip_units(mapping): - return {key: array_strip_units(value) for 
key, value in mapping.items()} - original_unit = unit_registry.m array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit @@ -3222,21 +4121,24 @@ def _strip_units(mapping): "dropping": {"cond": condition, "drop": True}, } kwargs = variant_kwargs.get(variant) - kwargs_without_units = _strip_units(kwargs) - if variant not in ("masking", "dropping") and error is not None: with pytest.raises(error): ds.where(**kwargs) return + kwargs_without_units = { + key: strip_units(convert_units(value, {None: original_unit})) + for key, value in kwargs.items() + } + expected = attach_units( strip_units(ds).where(**kwargs_without_units), {"a": original_unit, "b": original_unit}, ) - result = ds.where(**kwargs) + actual = ds.where(**kwargs) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.xfail(reason="interpolate strips units") def test_interpolate_na(self, dtype): @@ -3261,11 +4163,11 @@ def test_interpolate_na(self, dtype): strip_units(ds).interpolate_na(dim="x"), {"a": unit_registry.degK, "b": unit_registry.Pa}, ) - result = ds.interpolate_na(dim="x") + actual = ds.interpolate_na(dim="x") - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="uses Dataset.where, which currently fails") + @pytest.mark.xfail(reason="wrong argument order for `where`") @pytest.mark.parametrize( "unit,error", ( @@ -3281,11 +4183,11 @@ def test_interpolate_na(self, dtype): def test_combine_first(self, unit, error, dtype): array1 = ( np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) - * unit_registry.degK + * unit_registry.m ) array2 = ( np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) - * unit_registry.Pa + * unit_registry.m ) x = np.arange(len(array1)) ds = xr.Dataset( @@ -3312,12 +4214,16 @@ def test_combine_first(self, unit, error, dtype): return expected = attach_units( - strip_units(ds).combine_first(strip_units(other)), + strip_units(ds).combine_first( + strip_units( + convert_units(other, {"a": unit_registry.m, "b": unit_registry.m}) + ) + ), {"a": unit_registry.m, "b": unit_registry.m}, ) - result = ds.combine_first(other) + actual = ds.combine_first(other) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "unit", @@ -3325,11 +4231,7 @@ def test_combine_first(self, unit, error, dtype): pytest.param(1, id="no_unit"), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), - pytest.param( - unit_registry.cm, - id="compatible_unit", - marks=pytest.mark.xfail(reason="identical does not check units yet"), - ), + pytest.param(unit_registry.cm, id="compatible_unit"), pytest.param(unit_registry.m, id="identical_unit"), ), ) @@ -3345,6 +4247,13 @@ def test_combine_first(self, unit, error, dtype): ) @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) def test_comparisons(self, func, variation, unit, dtype): + def is_compatible(a, b): + a = a if a is not None else 1 + b = b if b is not None else 1 + quantity = np.arange(5) * a + + return a == b or quantity.check(b) + array1 = np.linspace(0, 5, 10).astype(dtype) array2 = np.linspace(-5, 0, 10).astype(dtype) @@ -3356,11 +4265,7 @@ def test_comparisons(self, func, variation, unit, dtype): x = coord * original_unit y = coord * original_unit - units = { - "data": (unit, original_unit, original_unit), - "dims": (original_unit, 
unit, original_unit), - "coords": (original_unit, original_unit, unit), - } + units = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = units.get(variation) ds = xr.Dataset( @@ -3371,36 +4276,27 @@ def test_comparisons(self, func, variation, unit, dtype): coords={"x": x, "y": ("x", y)}, ) - other = attach_units( - strip_units(ds), - { - "a": (data_unit, original_unit if quantity1.check(data_unit) else None), - "b": (data_unit, original_unit if quantity2.check(data_unit) else None), - "x": (dim_unit, original_unit if x.check(dim_unit) else None), - "y": (coord_unit, original_unit if y.check(coord_unit) else None), - }, - ) + other_units = { + "a": data_unit if quantity1.check(data_unit) else None, + "b": data_unit if quantity2.check(data_unit) else None, + "x": dim_unit if x.check(dim_unit) else None, + "y": coord_unit if y.check(coord_unit) else None, + } + other = attach_units(strip_units(convert_units(ds, other_units)), other_units) - # TODO: test dim coord once indexes leave units intact - # also, express this in terms of calls on the raw data array - # and then check the units - equal_arrays = ( - np.all(ds.a.data == other.a.data) - and np.all(ds.b.data == other.b.data) - and (np.all(x == other.x.data) or True) # dims can't be checked yet - and np.all(y == other.y.data) - ) - equal_units = ( - data_unit == original_unit - and coord_unit == original_unit - and dim_unit == original_unit - ) - expected = equal_arrays and (func.name != "identical" or equal_units) - result = func(ds, other) + units = extract_units(ds) + other_units = extract_units(other) + + equal_ds = all( + is_compatible(units[name], other_units[name]) for name in units.keys() + ) and (strip_units(ds).equals(strip_units(convert_units(other, units)))) + equal_units = units == other_units + expected = equal_ds and (func.name != "identical" or equal_units) + + actual = func(ds, other) - assert expected == result + assert expected == actual - @pytest.mark.xfail(reason="blocked by `where`") @pytest.mark.parametrize( "unit", ( @@ -3430,9 +4326,9 @@ def test_broadcast_like(self, unit, dtype): expected = attach_units( strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1) ) - result = ds1.broadcast_like(ds2) + actual = ds1.broadcast_like(ds2) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "unit", @@ -3446,38 +4342,34 @@ def test_broadcast_like(self, unit, dtype): ) def test_broadcast_equals(self, unit, dtype): left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m - left_array2 = np.zeros(shape=(2, 6), dtype=dtype) * unit_registry.m + left_array2 = np.zeros(shape=(3, 6), dtype=dtype) * unit_registry.m - right_array1 = array_attach_units( - np.ones(shape=(2,), dtype=dtype), - unit, - convert_from=unit_registry.m if left_array1.check(unit) else None, - ) - right_array2 = array_attach_units( - np.ones(shape=(2,), dtype=dtype), - unit, - convert_from=unit_registry.m if left_array2.check(unit) else None, - ) + right_array1 = np.ones(shape=(2,)) * unit + right_array2 = np.ones(shape=(3,)) * unit left = xr.Dataset( data_vars={ "a": xr.DataArray(data=left_array1, dims=("x", "y")), - "b": xr.DataArray(data=left_array2, dims=("x", "z")), + "b": xr.DataArray(data=left_array2, dims=("y", "z")), } ) right = xr.Dataset( data_vars={ "a": xr.DataArray(data=right_array1, dims="x"), - "b": xr.DataArray(data=right_array2, dims="x"), + "b": xr.DataArray(data=right_array2, dims="y"), } ) - expected = 
np.all(left_array1 == right_array1[:, None]) and np.all( - left_array2 == right_array2[:, None] - ) - result = left.broadcast_equals(right) + units = { + **extract_units(left), + **({} if left_array1.check(unit) else {"a": None, "b": None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & left_array1.check(unit) + actual = left.broadcast_equals(right) - assert expected == result + assert expected == actual @pytest.mark.parametrize( "func", @@ -3510,11 +4402,11 @@ def test_stacking_stacked(self, func, dtype): expected = attach_units( func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m} ) - result = func(stacked) + actual = func(stacked) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="tries to subscript scalar quantities") + @pytest.mark.xfail(reason="does not work with quantities yet") def test_to_stacked_array(self, dtype): labels = np.arange(5).astype(dtype) * unit_registry.s arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} @@ -3528,13 +4420,13 @@ def test_to_stacked_array(self, dtype): func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) - result = func(ds).rename(None) + actual = func(ds).rename(None) expected = attach_units( func(strip_units(ds)).rename(None), {None: unit_registry.m, "y": unit_registry.s}, ) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -3543,12 +4435,10 @@ def test_to_stacked_array(self, dtype): method("stack", a=("x", "y")), method("set_index", x="x2"), pytest.param( - method("shift", x=2), marks=pytest.mark.xfail(reason="sets all to nan") - ), - pytest.param( - method("roll", x=2, roll_coords=False), - marks=pytest.mark.xfail(reason="strips units"), + method("shift", x=2), + marks=pytest.mark.xfail(reason="tries to concatenate nan arrays"), ), + method("roll", x=2, roll_coords=False), method("sortby", "x2"), ), ids=repr, @@ -3581,9 +4471,9 @@ def test_stacking_reordering(self, func, dtype): expected = attach_units( func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK} ) - result = func(ds) + actual = func(ds) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.xfail(reason="indexes strip units") @pytest.mark.parametrize( @@ -3610,35 +4500,33 @@ def test_isel(self, indices, dtype): strip_units(ds).isel(x=indices), {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m}, ) - result = ds.isel(x=indices) + actual = ds.isel(x=indices) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="xarray does not support duck arrays in dimension coordinates" - ) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( - "values", + "raw_values", ( - pytest.param(12, id="single_value"), + pytest.param(10, id="single_value"), pytest.param([10, 5, 13], id="list_of_values"), pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), ), ) @pytest.mark.parametrize( - "units,error", + "unit,error", ( pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.ms, KeyError, id="compatible_unit"), - pytest.param(unit_registry.s, None, id="same_unit"), + pytest.param(unit_registry.dm, KeyError, 
id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_sel(self, values, units, error, dtype): + def test_sel(self, raw_values, unit, error, dtype): array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa - x = np.arange(len(array1)) * unit_registry.s + x = np.arange(len(array1)) * unit_registry.m ds = xr.Dataset( data_vars={ @@ -3648,46 +4536,46 @@ def test_sel(self, values, units, error, dtype): coords={"x": x}, ) - values_with_units = values * units + values = raw_values * unit - if error is not None: + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): with pytest.raises(error): - ds.sel(x=values_with_units) + ds.sel(x=values) return expected = attach_units( - strip_units(ds).sel(x=values), - {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s}, + strip_units(ds).sel(x=strip_units(convert_units(values, {None: x.units}))), + {"a": array1.units, "b": array2.units, "x": x.units}, ) - result = ds.sel(x=values_with_units) - assert_equal_with_units(expected, result) + actual = ds.sel(x=values) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="xarray does not support duck arrays in dimension coordinates" - ) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( - "values", + "raw_values", ( - pytest.param(12, id="single value"), - pytest.param([10, 5, 13], id="list of multiple values"), - pytest.param(np.array([9, 3, 7, 12]), id="array of multiple values"), + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), ), ) @pytest.mark.parametrize( - "units,error", + "unit,error", ( pytest.param(1, KeyError, id="no_units"), pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), - pytest.param(unit_registry.ms, KeyError, id="compatible_unit"), - pytest.param(unit_registry.s, None, id="same_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_loc(self, values, units, error, dtype): + def test_drop_sel(self, raw_values, unit, error, dtype): array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa - x = np.arange(len(array1)) * unit_registry.s + x = np.arange(len(array1)) * unit_registry.m ds = xr.Dataset( data_vars={ @@ -3697,36 +4585,76 @@ def test_loc(self, values, units, error, dtype): coords={"x": x}, ) - values_with_units = values * units + values = raw_values * unit - if error is not None: + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): with pytest.raises(error): - ds.loc[{"x": values_with_units}] + ds.drop_sel(x=values) return expected = attach_units( - strip_units(ds).loc[{"x": values}], - {"a": unit_registry.degK, "b": unit_registry.Pa, "x": unit_registry.s}, + strip_units(ds).drop_sel( + x=strip_units(convert_units(values, {None: x.units})) + ), + extract_units(ds), ) - result = ds.loc[{"x": values_with_units}] - assert_equal_with_units(expected, result) + actual = ds.drop_sel(x=values) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail( - reason="indexes strip units and head / tail / thin only support integers" + @pytest.mark.xfail(reason="indexes 
don't support units") + @pytest.mark.parametrize( + "raw_values", + ( + pytest.param(10, id="single_value"), + pytest.param([10, 5, 13], id="list_of_values"), + pytest.param(np.array([9, 3, 7, 12]), id="array_of_values"), + ), ) @pytest.mark.parametrize( "unit,error", ( - pytest.param(1, DimensionalityError, id="no_unit"), - pytest.param( - unit_registry.dimensionless, DimensionalityError, id="dimensionless" - ), - pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(1, KeyError, id="no_units"), + pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"), + pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"), + pytest.param(unit_registry.dm, KeyError, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) + def test_loc(self, raw_values, unit, error, dtype): + array1 = np.linspace(5, 10, 20).astype(dtype) * unit_registry.degK + array2 = np.linspace(0, 5, 20).astype(dtype) * unit_registry.Pa + x = np.arange(len(array1)) * unit_registry.m + + ds = xr.Dataset( + data_vars={ + "a": xr.DataArray(data=array1, dims="x"), + "b": xr.DataArray(data=array2, dims="x"), + }, + coords={"x": x}, + ) + + values = raw_values * unit + + if error is not None and not ( + isinstance(raw_values, (int, float)) and x.check(unit) + ): + with pytest.raises(error): + ds.loc[{"x": values}] + + return + + expected = attach_units( + strip_units(ds).loc[ + {"x": strip_units(convert_units(values, {None: x.units}))} + ], + {"a": array1.units, "b": array2.units, "x": x.units}, + ) + actual = ds.loc[{"x": values}] + assert_equal_with_units(expected, actual) + @pytest.mark.parametrize( "func", ( @@ -3736,7 +4664,7 @@ def test_loc(self, values, units, error, dtype): ), ids=repr, ) - def test_head_tail_thin(self, func, unit, error, dtype): + def test_head_tail_thin(self, func, dtype): array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa @@ -3754,18 +4682,10 @@ def test_head_tail_thin(self, func, unit, error, dtype): coords=coords, ) - kwargs = {name: value * unit for name, value in func.kwargs.items()} - - if error is not None: - with pytest.raises(error): - func(ds, **kwargs) - - return - expected = attach_units(func(strip_units(ds)), extract_units(ds)) - result = func(ds, **kwargs) + actual = func(ds) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "shape", @@ -3802,15 +4722,15 @@ def test_squeeze(self, shape, dtype): expected = attach_units(strip_units(ds).squeeze(), units) - result = ds.squeeze() - assert_equal_with_units(result, expected) + actual = ds.squeeze() + assert_equal_with_units(actual, expected) # try squeezing the dimensions separately names = tuple(dim for dim, coord in coords.items() if len(coord) == 1) for name in names: expected = attach_units(strip_units(ds).squeeze(dim=name), units) - result = ds.squeeze(dim=name) - assert_equal_with_units(result, expected) + actual = ds.squeeze(dim=name) + assert_equal_with_units(actual, expected) @pytest.mark.xfail(reason="ignores units") @pytest.mark.parametrize( @@ -3851,12 +4771,14 @@ def test_interp(self, unit, error): return + units = extract_units(ds) expected = attach_units( - strip_units(ds).interp(x=strip_units(new_coords)), extract_units(ds) + strip_units(ds).interp(x=strip_units(convert_units(new_coords, units))), + units, ) - result = 
ds.interp(x=new_coords) + actual = ds.interp(x=new_coords) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) @pytest.mark.xfail(reason="ignores units") @pytest.mark.parametrize( @@ -3911,16 +4833,15 @@ def test_interp_like(self, unit, error, dtype): return + units = extract_units(ds) expected = attach_units( - strip_units(ds).interp_like(strip_units(other)), extract_units(ds) + strip_units(ds).interp_like(strip_units(convert_units(other, units))), units ) - result = ds.interp_like(other) + actual = ds.interp_like(other) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) - @pytest.mark.xfail( - reason="pint does not implement np.result_type in __array_function__ yet" - ) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( @@ -3933,9 +4854,13 @@ def test_interp_like(self, unit, error, dtype): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - def test_reindex(self, unit, error): - array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK - array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa + def test_reindex(self, unit, error, dtype): + array1 = ( + np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK + ) + array2 = ( + np.linspace(1, 2, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa + ) coords = { "x": np.arange(10) * unit_registry.m, @@ -3955,20 +4880,21 @@ def test_reindex(self, unit, error): if error is not None: with pytest.raises(error): - ds.interp(x=new_coords) + ds.reindex(x=new_coords) return expected = attach_units( - strip_units(ds).reindex(x=strip_units(new_coords)), extract_units(ds) + strip_units(ds).reindex( + x=strip_units(convert_units(new_coords, {None: coords["x"].units})) + ), + extract_units(ds), ) - result = ds.reindex(x=new_coords) + actual = ds.reindex(x=new_coords) - assert_equal_with_units(result, expected) + assert_equal_with_units(actual, expected) - @pytest.mark.xfail( - reason="pint does not implement np.result_type in __array_function__ yet" - ) + @pytest.mark.xfail(reason="indexes don't support units") @pytest.mark.parametrize( "unit,error", ( @@ -4021,12 +4947,14 @@ def test_reindex_like(self, unit, error, dtype): return + units = extract_units(ds) expected = attach_units( - strip_units(ds).reindex_like(strip_units(other)), extract_units(ds) + strip_units(ds).reindex_like(strip_units(convert_units(other, units))), + units, ) - result = ds.reindex_like(other) + actual = ds.reindex_like(other) - assert_equal_with_units(result, expected) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -4036,18 +4964,10 @@ def test_reindex_like(self, unit, error, dtype): method("integrate", coord="x"), pytest.param( method("quantile", q=[0.25, 0.75]), - marks=pytest.mark.xfail( - reason="pint does not implement nanpercentile yet" - ), - ), - pytest.param( - method("reduce", func=np.sum, dim="x"), - marks=pytest.mark.xfail(reason="strips units"), - ), - pytest.param( - method("map", np.fabs), - marks=pytest.mark.xfail(reason="fabs strips units"), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), ), + method("reduce", func=np.sum, dim="x"), + method("map", np.fabs), ), ids=repr, ) @@ -4073,27 +4993,22 @@ def test_computation(self, func, dtype): units = extract_units(ds) expected = attach_units(func(strip_units(ds)), units) - result = func(ds) + actual = func(ds) - assert_equal_with_units(expected, result) + 
assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", ( - pytest.param( - method("groupby", "x"), marks=pytest.mark.xfail(reason="strips units") - ), - pytest.param( - method("groupby_bins", "x", bins=4), - marks=pytest.mark.xfail(reason="strips units"), - ), + method("groupby", "x"), + method("groupby_bins", "x", bins=4), method("coarsen", x=2), pytest.param( method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units") ), pytest.param( method("rolling_exp", x=3), - marks=pytest.mark.xfail(reason="strips units"), + marks=pytest.mark.xfail(reason="uses numbagg which strips units"), ), ), ids=repr, @@ -4122,11 +5037,10 @@ def test_computation_objects(self, func, dtype): args = [] if func.name != "groupby" else ["y"] reduce_func = method("mean", *args) expected = attach_units(reduce_func(func(strip_units(ds))), units) - result = reduce_func(func(ds)) + actual = reduce_func(func(ds)) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) - @pytest.mark.xfail(reason="strips units") def test_resample(self, dtype): array1 = ( np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK @@ -4150,28 +5064,20 @@ def test_resample(self, dtype): func = method("resample", time="6m") expected = attach_units(func(strip_units(ds)).mean(), units) - result = func(ds).mean() + actual = func(ds).mean() - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", ( - pytest.param( - method("assign", c=lambda ds: 10 * ds.b), - marks=pytest.mark.xfail(reason="strips units"), - ), - pytest.param( - method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), - marks=pytest.mark.xfail(reason="strips units"), - ), - pytest.param(method("first")), - pytest.param(method("last")), + method("assign", c=lambda ds: 10 * ds.b), + method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), + method("first"), + method("last"), pytest.param( method("quantile", q=[0.25, 0.5, 0.75], dim="x"), - marks=pytest.mark.xfail( - reason="dataset groupby does not implement quantile" - ), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), ), ), ids=repr, @@ -4204,9 +5110,9 @@ def test_grouped_operations(self, func, dtype): expected = attach_units( func(strip_units(ds).groupby("y"), **stripped_kwargs), units ) - result = func(ds.groupby("y")) + actual = func(ds.groupby("y")) - assert_equal_with_units(expected, result) + assert_equal_with_units(expected, actual) @pytest.mark.parametrize( "func", @@ -4220,7 +5126,7 @@ def test_grouped_operations(self, func, dtype): method("rename_dims", x="offset_x"), method("swap_dims", {"x": "x2"}), method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1), - method("drop_sel", labels="x"), + method("drop_vars", "x"), method("drop_dims", "z"), method("set_coords", names="c"), method("reset_coords", names="x2"), @@ -4252,26 +5158,25 @@ def test_content_manipulation(self, func, dtype): }, coords={"x": x, "y": y, "z": z, "x2": ("x", x2)}, ) - units = extract_units(ds) - units.update( - { + units = { + **extract_units(ds), + **{ "y2": unit_registry.mm, "x_mm": unit_registry.mm, "offset_x": unit_registry.m, "d": unit_registry.Pa, "temperature": unit_registry.degK, - } - ) + }, + } stripped_kwargs = { key: strip_units(value) for key, value in func.kwargs.items() } expected = attach_units(func(strip_units(ds), **stripped_kwargs), units) - result = func(ds) + actual = func(ds) - assert_equal_with_units(expected, result) + 
assert_equal_with_units(expected, actual)
 
-    @pytest.mark.xfail(reason="blocked by reindex")
     @pytest.mark.parametrize(
         "unit,error",
         (
@@ -4284,7 +5189,16 @@ def test_content_manipulation(self, func, dtype):
             pytest.param(unit_registry.m, None, id="identical_unit"),
         ),
     )
-    @pytest.mark.parametrize("variant", ("data", "dims", "coords"))
+    @pytest.mark.parametrize(
+        "variant",
+        (
+            "data",
+            pytest.param(
+                "dims", marks=pytest.mark.xfail(reason="indexes don't support units")
+            ),
+            "coords",
+        ),
+    )
     def test_merge(self, variant, unit, error, dtype):
         original_data_unit = unit_registry.m
         original_dim_unit = unit_registry.m
@@ -4325,6 +5239,6 @@ def test_merge(self, variant, unit, error, dtype):
         converted = convert_units(right, units)
 
         expected = attach_units(strip_units(left).merge(strip_units(converted)), units)
-        result = left.merge(right)
+        actual = left.merge(right)
 
-        assert_equal_with_units(expected, result)
+        assert_equal_with_units(expected, actual)
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 1d83e16a5bd..62fde920b1e 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -432,7 +432,7 @@ def test_concat(self):
         assert_identical(
             Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b")
         )
-        with raises_regex(ValueError, "inconsistent dimensions"):
+        with raises_regex(ValueError, "Variable has dimensions"):
             Variable.concat([v, Variable(["c"], y)], "b")
         # test indexers
         actual = Variable.concat(
@@ -451,16 +451,12 @@ def test_concat(self):
             Variable.concat([v[:, 0], v[:, 1:]], "x")
 
     def test_concat_attrs(self):
-        # different or conflicting attributes should be removed
+        # always keep attrs from first variable
         v = self.cls("a", np.arange(5), {"foo": "bar"})
         w = self.cls("a", np.ones(5))
         expected = self.cls(
             "a", np.concatenate([np.arange(5), np.ones(5)])
         ).to_base_variable()
-        assert_identical(expected, Variable.concat([v, w], "a"))
-        w.attrs["foo"] = 2
-        assert_identical(expected, Variable.concat([v, w], "a"))
-        w.attrs["foo"] = "bar"
         expected.attrs["foo"] = "bar"
         assert_identical(expected, Variable.concat([v, w], "a"))
 
@@ -1542,6 +1538,14 @@ def test_quantile_chunked_dim_error(self):
         with raises_regex(ValueError, "dimension 'x'"):
             v.quantile(0.5, dim="x")
 
+    @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])
+    def test_quantile_out_of_bounds(self, q):
+        v = Variable(["x", "y"], self.d)
+
+        # escape special characters
+        with raises_regex(ValueError, r"Quantiles must be in the range \[0, 1\]"):
+            v.quantile(q, dim="x")
+
     @requires_dask
     @requires_bottleneck
     def test_rank_dask_raises(self):
diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py
index ae2c5c574b6..8ab2b7cfe31 100644
--- a/xarray/ufuncs.py
+++ b/xarray/ufuncs.py
@@ -132,14 +132,67 @@ def _create_op(name):
 
     return func
 
-__all__ = """logaddexp logaddexp2 conj exp log log2 log10 log1p expm1 sqrt
-             square sin cos tan arcsin arccos arctan arctan2 hypot sinh cosh
-             tanh arcsinh arccosh arctanh deg2rad rad2deg logical_and
-             logical_or logical_xor logical_not maximum minimum fmax fmin
-             isreal iscomplex isfinite isinf isnan signbit copysign nextafter
-             ldexp fmod floor ceil trunc degrees radians rint fix angle real
-             imag fabs sign frexp fmod
-             """.split()
+__all__ = (  # noqa: F822
+    "angle",
+    "arccos",
+    "arccosh",
+    "arcsin",
+    "arcsinh",
+    "arctan",
+    "arctan2",
+    "arctanh",
+    "ceil",
+    "conj",
+    "copysign",
+    "cos",
+    "cosh",
+    "deg2rad",
+    "degrees",
+    "exp",
+    "expm1",
+    "fabs",
+    "fix",
+    "floor",
+    "fmax",
+    "fmin",
+    "fmod",
+    "frexp",
"hypot", + "imag", + "iscomplex", + "isfinite", + "isinf", + "isnan", + "isreal", + "ldexp", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "nextafter", + "rad2deg", + "radians", + "real", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +) + for name in __all__: globals()[name] = _create_op(name) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 0d6d147f0bb..6a0e62cc9dc 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -78,6 +78,13 @@ def netcdf_and_hdf5_versions(): def show_versions(file=sys.stdout): + """ print the versions of xarray and its dependencies + + Parameters + ---------- + file : file-like, optional + print to the given file-like object. Defaults to sys.stdout. + """ sys_info = get_sys_info() try: