diff --git a/.binder/environment.yml b/.binder/environment.yml index 6fd5829c5e6..6caea42df87 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -2,7 +2,7 @@ name: xarray-examples channels: - conda-forge dependencies: - - python=3.8 + - python=3.9 - boto3 - bottleneck - cartopy @@ -26,6 +26,7 @@ dependencies: - pandas - pint - pip + - pooch - pydap - pynio - rasterio diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 00000000000..95cb3eea4e3 --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1 @@ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes index a52f4ca283a..7a79ddd6b0b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,4 @@ # reduce the number of merge conflicts doc/whats-new.rst merge=union +# allow installing from git archives +.git_archival.txt export-subst diff --git a/.github/actions/detect-ci-trigger/action.yaml b/.github/actions/detect-ci-trigger/action.yaml deleted file mode 100644 index c255d0c57cc..00000000000 --- a/.github/actions/detect-ci-trigger/action.yaml +++ /dev/null @@ -1,29 +0,0 @@ -name: Detect CI Trigger -description: | - Detect a keyword used to control the CI in the subject line of a commit message. -inputs: - keyword: - description: | - The keyword to detect. - required: true -outputs: - trigger-found: - description: | - true if the keyword has been found in the subject line of the commit message - value: ${{ steps.detect-trigger.outputs.CI_TRIGGERED }} -runs: - using: "composite" - steps: - - name: detect trigger - id: detect-trigger - run: | - bash $GITHUB_ACTION_PATH/script.sh ${{ github.event_name }} ${{ inputs.keyword }} - shell: bash - - name: show detection result - run: | - echo "::group::final summary" - echo "commit message: ${{ steps.detect-trigger.outputs.COMMIT_MESSAGE }}" - echo "trigger keyword: ${{ inputs.keyword }}" - echo "trigger found: ${{ steps.detect-trigger.outputs.CI_TRIGGERED }}" - echo "::endgroup::" - shell: bash diff --git a/.github/actions/detect-ci-trigger/script.sh b/.github/actions/detect-ci-trigger/script.sh deleted file mode 100644 index c98175a5a08..00000000000 --- a/.github/actions/detect-ci-trigger/script.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env bash -event_name="$1" -keyword="$2" - -echo "::group::fetch a sufficient number of commits" -echo "skipped" -# git log -n 5 2>&1 -# if [[ "$event_name" == "pull_request" ]]; then -# ref=$(git log -1 --format='%H') -# git -c protocol.version=2 fetch --deepen=2 --no-tags --prune --progress -q origin $ref 2>&1 -# git log FETCH_HEAD -# git checkout FETCH_HEAD -# else -# echo "nothing to do." -# fi -# git log -n 5 2>&1 -echo "::endgroup::" - -echo "::group::extracting the commit message" -echo "event name: $event_name" -if [[ "$event_name" == "pull_request" ]]; then - ref="HEAD^2" -else - ref="HEAD" -fi - -commit_message="$(git log -n 1 --pretty=format:%s "$ref")" - -if [[ $(echo $commit_message | wc -l) -le 1 ]]; then - echo "commit message: '$commit_message'" -else - echo -e "commit message:\n--- start ---\n$commit_message\n--- end ---" -fi -echo "::endgroup::" - -echo "::group::scanning for the keyword" -echo "searching for: '$keyword'" -if echo "$commit_message" | grep -qF "$keyword"; then - result="true" -else - result="false" -fi -echo "keyword detected: $result" -echo "::endgroup::" - -echo "::set-output name=COMMIT_MESSAGE::$commit_message" -echo "::set-output name=CI_TRIGGERED::$result" diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..bad6ba3f62a --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + - package-ecosystem: 'github-actions' + directory: '/' + schedule: + # Check for updates once a week + interval: 'weekly' diff --git a/.github/workflows/cancel-duplicate-runs.yaml b/.github/workflows/cancel-duplicate-runs.yaml new file mode 100644 index 00000000000..9f74360b034 --- /dev/null +++ b/.github/workflows/cancel-duplicate-runs.yaml @@ -0,0 +1,15 @@ +name: Cancel +on: + workflow_run: + workflows: ["CI", "CI Additional", "CI Upstream"] + types: + - requested +jobs: + cancel: + name: Cancel previous runs + runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' + steps: + - uses: styfle/cancel-workflow-action@0.9.1 + with: + workflow_id: ${{ github.event.workflow.id }} diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index fdc61f2f4f7..ed731b25f76 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -12,14 +12,16 @@ jobs: detect-ci-trigger: name: detect ci trigger runs-on: ubuntu-latest - if: github.event_name == 'push' || github.event_name == 'pull_request' + if: | + github.repository == 'pydata/xarray' + && (github.event_name == 'push' || github.event_name == 'pull_request') outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - uses: actions/checkout@v2 with: fetch-depth: 2 - - uses: ./.github/actions/detect-ci-trigger + - uses: xarray-contrib/ci-trigger@v1.1 id: detect-trigger with: keyword: "[skip-ci]" @@ -42,26 +44,16 @@ jobs: "py37-min-all-deps", "py37-min-nep18", "py38-all-but-dask", - "py38-backend-api-v2", "py38-flaky", ] steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.6.0 - with: - access_token: ${{ github.token }} - uses: actions/checkout@v2 with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set environment variables run: | - if [[ ${{ matrix.env }} == "py38-backend-api-v2" ]] ; - then - echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV - echo "XARRAY_BACKEND_API=v2" >> $GITHUB_ENV - - elif [[ ${{ matrix.env }} == "py38-flaky" ]] ; + if [[ ${{ matrix.env }} == "py38-flaky" ]] ; then echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV echo "PYTEST_EXTRA_FLAGS=--run-flaky --run-network-tests" >> $GITHUB_ENV @@ -111,7 +103,7 @@ jobs: $PYTEST_EXTRA_FLAGS - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2.0.2 with: file: ./coverage.xml flags: unittests,${{ matrix.env }} @@ -121,17 +113,12 @@ jobs: doctest: name: Doctests runs-on: "ubuntu-latest" - needs: detect-ci-trigger - if: needs.detect-ci-trigger.outputs.triggered == 'false' + if: github.repository == 'pydata/xarray' defaults: run: shell: bash -l {0} steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.6.0 - with: - access_token: ${{ github.token }} - uses: actions/checkout@v2 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -169,10 +156,6 @@ jobs: shell: bash -l {0} steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.6.0 - with: - access_token: ${{ github.token }} - uses: actions/checkout@v2 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -186,6 +169,6 @@ jobs: - name: minimum versions policy run: | - mamba install -y pyyaml conda + mamba install -y pyyaml conda python-dateutil python ci/min_deps_check.py ci/requirements/py37-bare-minimum.yml python ci/min_deps_check.py ci/requirements/py37-min-all-deps.yml diff --git a/.github/workflows/ci-pre-commit-autoupdate.yaml b/.github/workflows/ci-pre-commit-autoupdate.yaml new file mode 100644 index 00000000000..b10a541197e --- /dev/null +++ b/.github/workflows/ci-pre-commit-autoupdate.yaml @@ -0,0 +1,44 @@ +name: "pre-commit autoupdate CI" + +on: + schedule: + - cron: "0 0 * * 0" # every Sunday at 00:00 UTC + workflow_dispatch: + + +jobs: + autoupdate: + name: 'pre-commit autoupdate' + runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' + steps: + - name: checkout + uses: actions/checkout@v2 + - name: Cache pip and pre-commit + uses: actions/cache@v2 + with: + path: | + ~/.cache/pre-commit + ~/.cache/pip + key: ${{ runner.os }}-pre-commit-autoupdate + - name: setup python + uses: actions/setup-python@v2 + - name: upgrade pip + run: python -m pip install --upgrade pip + - name: install dependencies + run: python -m pip install --upgrade pre-commit pyyaml packaging + - name: version info + run: python -m pip list + - name: autoupdate + uses: technote-space/create-pr-action@837dbe469b39f08d416889369a52e2a993625c84 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + EXECUTE_COMMANDS: | + python -m pre_commit autoupdate + python -m pre_commit run --all-files + COMMIT_MESSAGE: 'pre-commit: autoupdate hook versions' + COMMIT_NAME: 'github-actions[bot]' + COMMIT_EMAIL: 'github-actions[bot]@users.noreply.github.com' + PR_TITLE: 'pre-commit: autoupdate hook versions' + PR_BRANCH_PREFIX: 'pre-commit/' + PR_BRANCH_NAME: 'autoupdate-${PR_ID}' diff --git a/.github/workflows/ci-pre-commit.yml b/.github/workflows/ci-pre-commit.yml index 1ab5642367e..4bc5bddfdbc 100644 --- a/.github/workflows/ci-pre-commit.yml +++ b/.github/workflows/ci-pre-commit.yml @@ -10,7 +10,8 @@ jobs: linting: name: "pre-commit hooks" runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.0 + - uses: pre-commit/action@v2.0.3 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7d7326eb5c2..22a05eb1fc0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,14 +12,16 @@ jobs: detect-ci-trigger: name: detect ci trigger runs-on: ubuntu-latest - if: github.event_name == 'push' || github.event_name == 'pull_request' + if: | + github.repository == 'pydata/xarray' + && (github.event_name == 'push' || github.event_name == 'pull_request') outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - uses: actions/checkout@v2 with: fetch-depth: 2 - - uses: ./.github/actions/detect-ci-trigger + - uses: xarray-contrib/ci-trigger@v1.1 id: detect-trigger with: keyword: "[skip-ci]" @@ -35,12 +37,9 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.7", "3.8"] + # Bookend python versions + python-version: ["3.7", "3.9"] steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.6.0 - with: - access_token: ${{ github.token }} - uses: actions/checkout@v2 with: fetch-depth: 0 # Fetch all history for all branches and tags. @@ -59,8 +58,7 @@ jobs: uses: actions/cache@v2 with: path: ~/conda_pkgs_dir - key: - ${{ runner.os }}-conda-py${{ matrix.python-version }}-${{ + key: ${{ runner.os }}-conda-py${{ matrix.python-version }}-${{ hashFiles('ci/requirements/**.yml') }} - uses: conda-incubator/setup-miniconda@v2 with: @@ -89,16 +87,40 @@ jobs: run: | python -c "import xarray" - name: Run tests - run: | - python -m pytest -n 4 \ - --cov=xarray \ - --cov-report=xml + run: python -m pytest -n 4 + --cov=xarray + --cov-report=xml + --junitxml=pytest.xml + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v2 + with: + name: Test results for ${{ runner.os }}-${{ matrix.python-version }} + path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2.0.2 with: file: ./coverage.xml flags: unittests env_vars: RUNNER_OS,PYTHON_VERSION name: codecov-umbrella fail_ci_if_error: false + + publish-test-results: + needs: test + runs-on: ubuntu-latest + # the build-and-test job might be skipped, we don't need to run this job then + if: success() || failure() + + steps: + - name: Download Artifacts + uses: actions/download-artifact@v2 + with: + path: test-results + + - name: Publish Unit Test Results + uses: EnricoMi/publish-unit-test-result-action@v1 + with: + files: test-results/**/*.xml diff --git a/.github/workflows/parse_logs.py b/.github/workflows/parse_logs.py index 4d3bea54e50..545beaa4167 100644 --- a/.github/workflows/parse_logs.py +++ b/.github/workflows/parse_logs.py @@ -18,7 +18,7 @@ def extract_short_test_summary_info(lines): ) up_to_section_content = itertools.islice(up_to_start_of_section, 1, None) section_content = itertools.takewhile( - lambda l: l.startswith("FAILED"), up_to_section_content + lambda l: l.startswith("FAILED") or l.startswith("ERROR"), up_to_section_content ) content = "\n".join(section_content) diff --git a/.github/workflows/publish-test-results.yaml b/.github/workflows/publish-test-results.yaml new file mode 100644 index 00000000000..485383b31b4 --- /dev/null +++ b/.github/workflows/publish-test-results.yaml @@ -0,0 +1,44 @@ +# Copied from https://github.com/EnricoMi/publish-unit-test-result-action/blob/v1.18/README.md#support-fork-repositories-and-dependabot-branches + +name: Publish test results + +on: + workflow_run: + workflows: ["CI"] + types: + - completed + +jobs: + publish-test-results: + name: Publish test results + runs-on: ubuntu-latest + if: > + github.event.workflow_run.conclusion != 'skipped' && ( + github.event.sender.login == 'dependabot[bot]' || + github.event.workflow_run.head_repository.full_name != github.repository + ) + + steps: + - name: Download and extract artifacts + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + run: | + mkdir artifacts && cd artifacts + + artifacts_url=${{ github.event.workflow_run.artifacts_url }} + artifacts=$(gh api $artifacts_url -q '.artifacts[] | {name: .name, url: .archive_download_url}') + + IFS=$'\n' + for artifact in $artifacts + do + name=$(jq -r .name <<<$artifact) + url=$(jq -r .url <<<$artifact) + gh api $url > "$name.zip" + unzip -d "$name" "$name.zip" + done + + - name: Publish Unit Test Results + uses: EnricoMi/publish-unit-test-result-action@v1 + with: + commit: ${{ github.event.workflow_run.head_sha }} + files: "artifacts/**/*.xml" diff --git a/.github/workflows/pypi-release.yaml b/.github/workflows/pypi-release.yaml new file mode 100644 index 00000000000..432aea8a375 --- /dev/null +++ b/.github/workflows/pypi-release.yaml @@ -0,0 +1,96 @@ +name: Build and Upload xarray to PyPI +on: + release: + types: + - published + push: + tags: + - 'v*' + +jobs: + build-artifacts: + runs-on: ubuntu-latest + if: github.repository == 'pydata/xarray' + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + - uses: actions/setup-python@v2 + name: Install Python + with: + python-version: 3.8 + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools setuptools-scm wheel twine check-manifest + + - name: Build tarball and wheels + run: | + git clean -xdf + git restore -SW . + python -m build --sdist --wheel . + + - name: Check built artifacts + run: | + python -m twine check dist/* + pwd + if [ -f dist/xarray-0.0.0.tar.gz ]; then + echo "❌ INVALID VERSION NUMBER" + exit 1 + else + echo "✅ Looks good" + fi + - uses: actions/upload-artifact@v2 + with: + name: releases + path: dist + + test-built-dist: + needs: build-artifacts + runs-on: ubuntu-latest + steps: + - uses: actions/setup-python@v2 + name: Install Python + with: + python-version: 3.8 + - uses: actions/download-artifact@v2 + with: + name: releases + path: dist + - name: List contents of built dist + run: | + ls -ltrh + ls -ltrh dist + - name: Publish package to TestPyPI + if: github.event_name == 'push' + uses: pypa/gh-action-pypi-publish@v1.4.2 + with: + user: __token__ + password: ${{ secrets.TESTPYPI_TOKEN }} + repository_url: https://test.pypi.org/legacy/ + verbose: true + + - name: Check uploaded package + if: github.event_name == 'push' + run: | + sleep 3 + python -m pip install --upgrade pip + python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade xarray + python -m xarray.util.print_versions + + upload-to-pypi: + needs: test-built-dist + if: github.event_name == 'release' + runs-on: ubuntu-latest + steps: + - uses: actions/download-artifact@v2 + with: + name: releases + path: dist + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@v1.4.2 + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} + verbose: true diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index dda762878c5..9b1664a0292 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -2,10 +2,10 @@ name: CI Upstream on: push: branches: - - master + - main pull_request: branches: - - master + - main schedule: - cron: "0 0 * * *" # Daily “At 00:00” UTC workflow_dispatch: # allows you to trigger the workflow run manually @@ -14,14 +14,16 @@ jobs: detect-ci-trigger: name: detect upstream-dev ci trigger runs-on: ubuntu-latest - if: github.event_name == 'push' || github.event_name == 'pull_request' + if: | + github.repository == 'pydata/xarray' + && (github.event_name == 'push' || github.event_name == 'pull_request') outputs: triggered: ${{ steps.detect-trigger.outputs.trigger-found }} steps: - uses: actions/checkout@v2 with: fetch-depth: 2 - - uses: ./.github/actions/detect-ci-trigger + - uses: xarray-contrib/ci-trigger@v1.1 id: detect-trigger with: keyword: "[test-upstream]" @@ -31,27 +33,24 @@ jobs: runs-on: ubuntu-latest needs: detect-ci-trigger if: | - always() - && github.repository == 'pydata/xarray' - && ( - (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') - || needs.detect-ci-trigger.outputs.triggered == 'true' - ) + always() + && ( + (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') + || needs.detect-ci-trigger.outputs.triggered == 'true' + ) defaults: run: shell: bash -l {0} strategy: fail-fast: false matrix: - python-version: ["3.8"] + python-version: ["3.9"] outputs: artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.6.0 - with: - access_token: ${{ github.token }} - uses: actions/checkout@v2 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. - uses: conda-incubator/setup-miniconda@v2 with: channels: conda-forge @@ -64,6 +63,9 @@ jobs: run: | mamba env update -f ci/requirements/environment.yml bash ci/install-upstream-wheels.sh + - name: Install xarray + run: | + python -m pip install --no-deps -e . - name: Version info run: | conda info -a @@ -77,7 +79,7 @@ jobs: id: status run: | set -euo pipefail - python -m pytest -rf | tee output-${{ matrix.python-version }}-log || ( + python -m pytest --timeout=60 -rf | tee output-${{ matrix.python-version }}-log || ( echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false ) - name: Upload artifacts @@ -96,9 +98,8 @@ jobs: name: report needs: upstream-dev if: | - always() + failure() && github.event_name == 'schedule' - && github.repository == 'pydata/xarray' && needs.upstream-dev.outputs.artifacts_availability == 'true' runs-on: ubuntu-latest defaults: @@ -121,7 +122,7 @@ jobs: shopt -s globstar python .github/workflows/parse_logs.py logs/**/*-log - name: Report failures - uses: actions/github-script@v3 + uses: actions/github-script@v4.0.2 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.gitignore b/.gitignore index 5f02700de37..90f4a10ed5f 100644 --- a/.gitignore +++ b/.gitignore @@ -65,10 +65,11 @@ dask-worker-space/ # xarray specific doc/_build -doc/generated +generated/ xarray/tests/data/*.grib.*.idx # Sync tools Icon* .ipynb_checkpoints +doc/rasm.zarr diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b0fa21a7bf9..53525d0def9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,34 +1,50 @@ # https://pre-commit.com/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.4.0 + rev: v4.0.1 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml # isort should run before black as black sometimes tweaks the isort output - repo: https://github.com/PyCQA/isort - rev: 5.7.0 + rev: 5.9.3 hooks: - id: isort # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 21.7b0 hooks: - id: black - repo: https://github.com/keewis/blackdoc - rev: v0.3.2 + rev: v0.3.4 hooks: - id: blackdoc - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.2 hooks: - id: flake8 + # - repo: https://github.com/Carreau/velin + # rev: 0.0.8 + # hooks: + # - id: velin + # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 # Must match ci/requirements/*.yml + rev: v0.910 hooks: - id: mypy + # Copied from setup.cfg exclude: "properties|asv_bench" + additional_dependencies: [ + # Type stubs + types-python-dateutil, + types-pkg_resources, + types-PyYAML, + types-pytz, + # Dependencies that are typed + numpy, + typing-extensions==3.10.0.0, + ] # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 # - repo: https://github.com/asottile/pyupgrade # rev: v1.22.1 diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index 5352d427909..16dc3b94196 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -1,4 +1,4 @@ -# How to issue an xarray release in 20 easy steps +# How to issue an xarray release in 16 easy steps Time required: about an hour. @@ -13,31 +13,27 @@ upstream https://github.com/pydata/xarray (push) - 1. Ensure your master branch is synced to upstream: + 1. Ensure your main branch is synced to upstream: ```sh - git switch master - git pull upstream master + git switch main + git pull upstream main ``` 2. Confirm there are no commits on stable that are not yet merged ([ref](https://github.com/pydata/xarray/pull/4440)): ```sh - git merge upstream stable + git merge upstream/stable ``` - 2. Add a list of contributors with: + 3. Add a list of contributors with: ```sh - git log "$(git tag --sort="v:refname" | sed -n 'x;$p').." --format=%aN | sort -u | perl -pe 's/\n/$1, /' - ``` - or by substituting the _previous_ release in {0.X.Y-1}: - ```sh - git log v{0.X.Y-1}.. --format=%aN | sort -u | perl -pe 's/\n/$1, /' + git log "$(git tag --sort="v:refname" | tail -1).." --format=%aN | sort -u | perl -pe 's/\n/$1, /' ``` This will return the number of contributors: ```sh - git log v{0.X.Y-1}.. --format=%aN | sort -u | wc -l + git log $(git tag --sort="v:refname" | tail -1).. --format=%aN | sort -u | wc -l ``` - 3. Write a release summary: ~50 words describing the high level features. This + 4. Write a release summary: ~50 words describing the high level features. This will be used in the release emails, tweets, GitHub release notes, etc. - 4. Look over whats-new.rst and the docs. Make sure "What's New" is complete + 5. Look over whats-new.rst and the docs. Make sure "What's New" is complete (check the date!) and add the release summary at the top. Things to watch out for: - Important new features should be highlighted towards the top. @@ -46,67 +42,52 @@ upstream https://github.com/pydata/xarray (push) due to a bad merge. Check for these before a release by using git diff, e.g., `git diff v{0.X.Y-1} whats-new.rst` where {0.X.Y-1} is the previous release. - 5. If possible, open a PR with the release summary and whatsnew changes. - 6. After merging, again ensure your master branch is synced to upstream: + 6. Open a PR with the release summary and whatsnew changes; in particular the + release headline should get feedback from the team on what's important to include. + 7. After merging, again ensure your main branch is synced to upstream: ```sh - git pull upstream master + git pull upstream main ``` - 7. If you have any doubts, run the full test suite one final time! + 8. If you have any doubts, run the full test suite one final time! ```sh pytest ``` - 8. Check that the ReadTheDocs build is passing. - 9. On the master branch, commit the release in git: - ```sh - git commit -am 'Release v{0.X.Y}' - ``` -10. Tag the release: - ```sh - git tag -a v{0.X.Y} -m 'v{0.X.Y}' - ``` -11. Build source and binary wheels for PyPI: - ```sh - git clean -xdf # this deletes all uncommitted changes! - python setup.py bdist_wheel sdist - ``` -12. Use twine to check the package build: - ```sh - twine check dist/xarray-{0.X.Y}* - ``` -13. Use twine to register and upload the release on PyPI. Be careful, you can't - take this back! - ```sh - twine upload dist/xarray-{0.X.Y}* - ``` - You will need to be listed as a package owner at - for this to work. -14. Push your changes to master: - ```sh - git push upstream master - git push upstream --tags - ``` -15. Update the stable branch (used by ReadTheDocs) and switch back to master: + 9. Check that the ReadTheDocs build is passing. +10. Issue the release on GitHub. Click on "Draft a new release" at + . Type in the version number (with a "v") + and paste the release summary in the notes. +11. This should automatically trigger an upload of the new build to PyPI via GitHub Actions. + Check this has run [here](https://github.com/pydata/xarray/actions/workflows/pypi-release.yaml), + and that the version number you expect is displayed [on PyPI](https://pypi.org/project/xarray/) +12. Update the stable branch (used by ReadTheDocs) and switch back to main: ```sh git switch stable - git rebase master + git rebase main git push --force upstream stable - git switch master + git switch main ``` + You may need to first fetch it with `git fetch upstream`, + and check out a local version with `git checkout -b stable upstream/stable`. + It's OK to force push to `stable` if necessary. (We also update the stable branch with `git cherry-pick` for documentation only fixes that apply the current released version.) -16. Add a section for the next release {0.X.Y+1} to doc/whats-new.rst: +13. Add a section for the next release {0.X.Y+1} to doc/whats-new.rst: ```rst - .. _whats-new.{0.X.Y+1}: + .. _whats-new.0.X.Y+1: - v{0.X.Y+1} (unreleased) + v0.X.Y+1 (unreleased) --------------------- + New Features + ~~~~~~~~~~~~ + + Breaking changes ~~~~~~~~~~~~~~~~ - New Features + Deprecations ~~~~~~~~~~~~ @@ -120,20 +101,19 @@ upstream https://github.com/pydata/xarray (push) Internal Changes ~~~~~~~~~~~~~~~~ + ``` -17. Commit your changes and push to master again: +14. Commit your changes and push to main again: ```sh git commit -am 'New whatsnew section' - git push upstream master + git push upstream main ``` - You're done pushing to master! -18. Issue the release on GitHub. Click on "Draft a new release" at - . Type in the version number - and paste the release summary in the notes. -19. Update the docs. Login to + You're done pushing to main! + +15. Update the docs. Login to and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -20. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I +16. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): - pydata@googlegroups.com @@ -144,6 +124,7 @@ upstream https://github.com/pydata/xarray (push) Google search will turn up examples of prior release announcements (look for "ANN xarray"). + Some of these groups require you to be subscribed in order to email them. diff --git a/README.rst b/README.rst index e258a8ccd23..e246d5474c9 100644 --- a/README.rst +++ b/README.rst @@ -1,9 +1,9 @@ xarray: N-D labeled arrays and datasets ======================================= -.. image:: https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=master +.. image:: https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=main :target: https://github.com/pydata/xarray/actions?query=workflow%3ACI -.. image:: https://codecov.io/gh/pydata/xarray/branch/master/graph/badge.svg +.. image:: https://codecov.io/gh/pydata/xarray/branch/main/graph/badge.svg :target: https://codecov.io/gh/pydata/xarray .. image:: https://readthedocs.org/projects/xray/badge/?version=latest :target: https://xarray.pydata.org/ diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index d35a2a223a2..83a2aa9f010 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -15,7 +15,7 @@ // List of branches to benchmark. If not provided, defaults to "master" // (for git) or "default" (for mercurial). - "branches": ["master"], // for git + "branches": ["main"], // for git // "branches": ["default"], // for mercurial // The DVCS being used. If not set, it will be automatically diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index aa9662d44f9..308ca2afda4 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -26,13 +26,13 @@ def setup(self): {"B": xr.DataArray(data, coords={"T": t + t_size}, dims=("T", "X", "Y"))} ) - def time_combine_manual(self): + def time_combine_nested(self): datasets = [[self.dsA0, self.dsA1], [self.dsB0, self.dsB1]] - xr.combine_manual(datasets, concat_dim=[None, "t"]) + xr.combine_nested(datasets, concat_dim=[None, "T"]) - def time_auto_combine(self): + def time_combine_by_coords(self): """Also has to load and arrange t coordinate""" datasets = [self.dsA0, self.dsA1, self.dsB0, self.dsB1] - xr.combine_auto(datasets) + xr.combine_by_coords(datasets) diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index d1ffbc34706..e99911d752c 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -59,7 +59,6 @@ def make_ds(self): coords={"lon": lons, "lat": lats, "time": times}, dims=("time", "lon", "lat"), name="foo", - encoding=None, attrs={"units": "foo units", "description": "a description"}, ) self.ds["bar"] = xr.DataArray( @@ -67,7 +66,6 @@ def make_ds(self): coords={"lon": lons, "lat": lats, "time": times}, dims=("time", "lon", "lat"), name="bar", - encoding=None, attrs={"units": "bar units", "description": "a description"}, ) self.ds["baz"] = xr.DataArray( @@ -75,7 +73,6 @@ def make_ds(self): coords={"lon": lons, "lat": lats}, dims=("lon", "lat"), name="baz", - encoding=None, attrs={"units": "baz units", "description": "a description"}, ) @@ -270,7 +267,6 @@ def make_ds(self, nfiles=10): coords={"lon": lons, "lat": lats, "time": times}, dims=("time", "lon", "lat"), name="foo", - encoding=None, attrs={"units": "foo units", "description": "a description"}, ) ds["bar"] = xr.DataArray( @@ -278,7 +274,6 @@ def make_ds(self, nfiles=10): coords={"lon": lons, "lat": lats, "time": times}, dims=("time", "lon", "lat"), name="bar", - encoding=None, attrs={"units": "bar units", "description": "a description"}, ) ds["baz"] = xr.DataArray( @@ -286,7 +281,6 @@ def make_ds(self, nfiles=10): coords={"lon": lons, "lat": lats}, dims=("lon", "lat"), name="baz", - encoding=None, attrs={"units": "baz units", "description": "a description"}, ) diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py index b218c0be870..405f6cd0530 100644 --- a/asv_bench/benchmarks/repr.py +++ b/asv_bench/benchmarks/repr.py @@ -1,10 +1,32 @@ +import numpy as np import pandas as pd import xarray as xr +class Repr: + def setup(self): + a = np.arange(0, 100) + data_vars = dict() + for i in a: + data_vars[f"long_variable_name_{i}"] = xr.DataArray( + name=f"long_variable_name_{i}", + data=np.arange(0, 20), + dims=[f"long_coord_name_{i}_x"], + coords={f"long_coord_name_{i}_x": np.arange(0, 20) * 2}, + ) + self.ds = xr.Dataset(data_vars) + self.ds.attrs = {f"attr_{k}": 2 for k in a} + + def time_repr(self): + repr(self.ds) + + def time_repr_html(self): + self.ds._repr_html_() + + class ReprMultiIndex: - def setup(self, key): + def setup(self): index = pd.MultiIndex.from_product( [range(10000), range(10000)], names=("level_0", "level_1") ) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index d5426af4aa1..93c3c6aed4e 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -67,3 +67,44 @@ def setup(self, *args, **kwargs): super().setup(**kwargs) self.ds = self.ds.chunk({"x": 100, "y": 50, "t": 50}) self.da_long = self.da_long.chunk({"x": 10000}) + + +class RollingMemory: + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + { + "var1": (("x", "y"), randn_xy), + "var2": (("x", "t"), randn_xt), + "var3": (("t",), randn_t), + }, + coords={ + "x": np.arange(nx), + "y": np.linspace(0, 1, ny), + "t": pd.date_range("1970-01-01", periods=nt, freq="D"), + "x_coords": ("x", np.linspace(1.1, 2.1, nx)), + }, + ) + + +class DataArrayRollingMemory(RollingMemory): + @parameterized("func", ["sum", "max", "mean"]) + def peakmem_ndrolling_reduce(self, func): + roll = self.ds.var1.rolling(x=10, y=4) + getattr(roll, func)() + + @parameterized("func", ["sum", "max", "mean"]) + def peakmem_1drolling_reduce(self, func): + roll = self.ds.var3.rolling(t=100) + getattr(roll, func)() + + +class DatasetRollingMemory(RollingMemory): + @parameterized("func", ["sum", "max", "mean"]) + def peakmem_ndrolling_reduce(self, func): + roll = self.ds.rolling(x=10, y=4) + getattr(roll, func)() + + @parameterized("func", ["sum", "max", "mean"]) + def peakmem_1drolling_reduce(self, func): + roll = self.ds.rolling(t=100) + getattr(roll, func)() diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index fe3e706f6a6..92a0f8fc7e7 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -1,8 +1,5 @@ #!/usr/bin/env bash -# TODO: add sparse back in, once Numba works with the development version of -# NumPy again: https://github.com/pydata/xarray/issues/4146 - conda uninstall -y --force \ numpy \ scipy \ @@ -10,12 +7,16 @@ conda uninstall -y --force \ matplotlib \ dask \ distributed \ + fsspec \ zarr \ cftime \ rasterio \ pint \ bottleneck \ - sparse + sparse \ + xarray +# to limit the runtime of Upstream CI +python -m pip install pytest-timeout python -m pip install \ -i https://pypi.anaconda.org/scipy-wheels-nightly/simple \ --no-deps \ @@ -39,5 +40,6 @@ python -m pip install \ git+https://github.com/Unidata/cftime \ git+https://github.com/mapbox/rasterio \ git+https://github.com/hgrecco/pint \ - git+https://github.com/pydata/bottleneck # \ - # git+https://github.com/pydata/sparse + git+https://github.com/pydata/bottleneck \ + git+https://github.com/pydata/sparse \ + git+https://github.com/intake/filesystem_spec diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index 3ffab645e8e..d2560fc9106 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -4,11 +4,12 @@ """ import itertools import sys -from datetime import datetime, timedelta +from datetime import datetime from typing import Dict, Iterator, Optional, Tuple -import conda.api +import conda.api # type: ignore[import] import yaml +from dateutil.relativedelta import relativedelta CHANNELS = ["conda-forge", "defaults"] IGNORE_DEPS = { @@ -25,14 +26,9 @@ "pytest-xdist", } -POLICY_MONTHS = {"python": 42, "numpy": 24, "setuptools": 42} +POLICY_MONTHS = {"python": 24, "numpy": 18, "setuptools": 42} POLICY_MONTHS_DEFAULT = 12 POLICY_OVERRIDE = { - # dask < 2.9 has trouble with nan-reductions - # TODO remove this special case and the matching note in installing.rst - # after January 2021. - "dask": (2, 9), - "distributed": (2, 9), # setuptools-scm doesn't work with setuptools < 36.7 (Nov 2017). # The conda metadata is malformed for setuptools < 38.4 (Jan 2018) # (it's missing a timestamp which prevents this tool from working). @@ -80,9 +76,9 @@ def parse_requirements(fname) -> Iterator[Tuple[str, int, int, Optional[int]]]: raise ValueError("non-numerical version: " + row) if len(version_tup) == 2: - yield (pkg, *version_tup, None) # type: ignore + yield (pkg, *version_tup, None) # type: ignore[misc] elif len(version_tup) == 3: - yield (pkg, *version_tup) # type: ignore + yield (pkg, *version_tup) # type: ignore[misc] else: raise ValueError("expected major.minor or major.minor.patch: " + row) @@ -148,28 +144,32 @@ def process_pkg( return pkg, fmt_version(req_major, req_minor, req_patch), "-", "-", "-", "(!)" policy_months = POLICY_MONTHS.get(pkg, POLICY_MONTHS_DEFAULT) - policy_published = datetime.now() - timedelta(days=policy_months * 30) - - policy_major = req_major - policy_minor = req_minor - policy_published_actual = req_published - for (major, minor), published in reversed(sorted(versions.items())): - if published < policy_published: - break - policy_major = major - policy_minor = minor - policy_published_actual = published + policy_published = datetime.now() - relativedelta(months=policy_months) + + filtered_versions = [ + version + for version, published in versions.items() + if published < policy_published + ] + policy_major, policy_minor = max(filtered_versions, default=(req_major, req_minor)) try: policy_major, policy_minor = POLICY_OVERRIDE[pkg] except KeyError: pass + policy_published_actual = versions[policy_major, policy_minor] if (req_major, req_minor) < (policy_major, policy_minor): status = "<" elif (req_major, req_minor) > (policy_major, policy_minor): status = "> (!)" - error("Package is too new: " + pkg) + delta = relativedelta(datetime.now(), policy_published_actual).normalized() + n_months = delta.years * 12 + delta.months + error( + f"Package is too new: {pkg}={req_major}.{req_minor} was " + f"published on {versions[req_major, req_minor]:%Y-%m-%d} " + f"which was {n_months} months ago (policy is {policy_months} months)" + ) else: status = "=" diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index e092272654b..a73c6679322 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -20,15 +20,21 @@ dependencies: - numba - numpy>=1.17 - pandas>=1.0 + - pooch + - pip + - pydata-sphinx-theme>=0.4.3 - rasterio>=1.1 - seaborn - setuptools - - sphinx=3.3 - - sphinx_rtd_theme>=0.4 + - sparse - sphinx-autosummary-accessors + - sphinx-book-theme >= 0.0.38 + - sphinx-copybutton + - sphinx-panels + - sphinx<4 - zarr>=2.4 - - pip - pip: - - scanpydoc + - sphinxext-rediraffe + - sphinxext-opengraph # relative to this file. Needs to be editable to be accepted. - -e ../.. diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 6de2bc8dc64..78ead40d5a2 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -10,8 +10,9 @@ dependencies: - cftime - dask - distributed + - fsspec!=2021.7.0 - h5netcdf - - h5py=2 + - h5py - hdf5 - hypothesis - iris diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 0f59d9570c8..f64ca3677cc 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - nodefaults dependencies: + - aiobotocore - boto3 - bottleneck - cartopy @@ -11,8 +12,9 @@ dependencies: - cftime - dask - distributed + - fsspec!=2021.7.0 - h5netcdf - - h5py=2 + - h5py - hdf5 - hypothesis - iris @@ -21,10 +23,12 @@ dependencies: - nc-time-axis - netcdf4 - numba + - numexpr - numpy - pandas - pint - - pip=20.2 + - pip + - pooch - pre-commit - pseudonetcdf - pydap diff --git a/ci/requirements/py37-bare-minimum.yml b/ci/requirements/py37-bare-minimum.yml index fbeb87032b7..408cf76fdd6 100644 --- a/ci/requirements/py37-bare-minimum.yml +++ b/ci/requirements/py37-bare-minimum.yml @@ -10,6 +10,6 @@ dependencies: - pytest-cov - pytest-env - pytest-xdist - - numpy=1.15 - - pandas=0.25 + - numpy=1.17 + - pandas=1.0 - setuptools=40.4 diff --git a/ci/requirements/py37-min-all-deps.yml b/ci/requirements/py37-min-all-deps.yml index feef86ddf5c..7743a086db0 100644 --- a/ci/requirements/py37-min-all-deps.yml +++ b/ci/requirements/py37-min-all-deps.yml @@ -8,46 +8,44 @@ dependencies: # When upgrading python, numpy, or pandas, must also change # doc/installing.rst and setup.py. - python=3.7 - - black - - boto3=1.9 - - bottleneck=1.2 + - boto3=1.13 + - bottleneck=1.3 - cartopy=0.17 - cdms2=3.1 - cfgrib=0.9 - - cftime=1.0 + - cftime=1.1 - coveralls - - dask=2.9 - - distributed=2.9 - - flake8 - - h5netcdf=0.7 - - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest + - dask=2.15 + - distributed=2.15 + - h5netcdf=0.8 + - h5py=2.10 - hdf5=1.10 - hypothesis - - iris=2.2 - - isort - - lxml=4.4 # Optional dep of pydap - - matplotlib-base=3.1 - - mypy=0.782 # Must match .pre-commit-config.yaml + - iris=2.4 + - lxml=4.5 # Optional dep of pydap + - matplotlib-base=3.2 - nc-time-axis=1.2 - - netcdf4=1.4 - - numba=0.46 - - numpy=1.15 - - pandas=0.25 +# netcdf follows a 1.major.minor[.patch] convention (see https://github.com/Unidata/netcdf4-python/issues/1090) +# bumping the netCDF4 version is currently blocked by #4491 + - netcdf4=1.5.3 + - numba=0.49 + - numpy=1.17 + - pandas=1.0 # - pint # See py37-min-nep18.yml - pip - - pseudonetcdf=3.0 + - pseudonetcdf=3.1 - pydap=3.2 - pynio=1.5 - pytest - pytest-cov - pytest-env - pytest-xdist - - rasterio=1.0 - - scipy=1.3 - - seaborn=0.9 + - rasterio=1.1 + - scipy=1.4 + - seaborn=0.10 - setuptools=40.4 # - sparse # See py37-min-nep18.yml - toolz=0.10 - - zarr=2.3 + - zarr=2.4 - pip: - numbagg==0.1 diff --git a/ci/requirements/py38-all-but-dask.yml b/ci/requirements/py38-all-but-dask.yml index 14930f5272d..3f82990f3b5 100644 --- a/ci/requirements/py38-all-but-dask.yml +++ b/ci/requirements/py38-all-but-dask.yml @@ -5,6 +5,7 @@ channels: dependencies: - python=3.8 - black + - aiobotocore - boto3 - bottleneck - cartopy @@ -12,15 +13,12 @@ dependencies: - cfgrib - cftime - coveralls - - flake8 - h5netcdf - - h5py=2 + - h5py - hdf5 - hypothesis - - isort - lxml # Optional dep of pydap - matplotlib-base - - mypy=0.790 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/design_notes/flexible_indexes_notes.md b/design_notes/flexible_indexes_notes.md new file mode 100644 index 00000000000..c7eb718720c --- /dev/null +++ b/design_notes/flexible_indexes_notes.md @@ -0,0 +1,398 @@ +# Proposal: Xarray flexible indexes refactoring + +Current status: https://github.com/pydata/xarray/projects/1 + +## 1. Data Model + +Indexes are used in Xarray to extract data from Xarray objects using coordinate labels instead of using integer array indices. Although the indexes used in an Xarray object can be accessed (or built on-the-fly) via public methods like `to_index()` or properties like `indexes`, those are mainly used internally. + +The goal of this project is to make those indexes 1st-class citizens of Xarray's data model. As such, indexes should clearly be separated from Xarray coordinates with the following relationships: + +- Index -> Coordinate: one-to-many +- Coordinate -> Index: one-to-zero-or-one + +An index may be built from one or more coordinates. However, each coordinate must relate to one index at most. Additionally, a coordinate may not be tied to any index. + +The order in which multiple coordinates relate to an index may matter. For example, Scikit-Learn's [`BallTree`](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html#sklearn.neighbors.BallTree) index with the Haversine metric requires providing latitude and longitude values in that specific order. As another example, the order in which levels are defined in a `pandas.MultiIndex` may affect its lexsort depth (see [MultiIndex sorting](https://pandas.pydata.org/pandas-docs/stable/user_guide/advanced.html#sorting-a-multiindex)). + +Xarray's current data model has the same index-coordinate relationships than stated above, although this assumes that multi-index "virtual" coordinates are counted as coordinates (we can consider them as such, with some constraints). More importantly, This refactoring would turn the current one-to-one relationship between a dimension and an index into a many-to-many relationship, which would overcome some current limitations. + +For example, we might want to select data along a dimension which has several coordinates: + +```python +>>> da + +array([...]) +Coordinates: + * drainage_area (river_profile) float64 ... + * chi (river_profile) float64 ... +``` + +In this example, `chi` is a transformation of the `drainage_area` variable that is often used in geomorphology. We'd like to select data along the river profile using either `da.sel(drainage_area=...)` or `da.sel(chi=...)` but that's not currently possible. We could rename the `river_profile` dimension to one of the coordinates, then use `sel` with that coordinate, then call `swap_dims` if we want to use `sel` with the other coordinate, but that's not ideal. We could also build a `pandas.MultiIndex` from `drainage_area` and `chi`, but that's not optimal (there's no hierarchical relationship between these two coordinates). + +Let's take another example: + +```python +>>> da + +array([[...], [...]]) +Coordinates: + * lon (x, y) float64 ... + * lat (x, y) float64 ... + * x (x) float64 ... + * y (y) float64 ... +``` + +This refactoring would allow creating a geographic index for `lat` and `lon` and two simple indexes for `x` and `y` such that we could select data with either `da.sel(lon=..., lat=...)` or `da.sel(x=..., y=...)`. + +Refactoring the dimension -> index one-to-one relationship into many-to-many would also introduce some issues that we'll need to address, e.g., ambiguous cases like `da.sel(chi=..., drainage_area=...)` where multiple indexes may potentially return inconsistent positional indexers along a dimension. + +## 2. Proposed API changes + +### 2.1 Index wrapper classes + +Every index that is used to select data from Xarray objects should inherit from a base class, e.g., `XarrayIndex`, that provides some common API. `XarrayIndex` subclasses would generally consist of thin wrappers around existing index classes such as `pandas.Index`, `pandas.MultiIndex`, `scipy.spatial.KDTree`, etc. + +There is a variety of features that an xarray index wrapper may or may not support: + +- 1-dimensional vs. 2-dimensional vs. n-dimensional coordinate (e.g., `pandas.Index` only supports 1-dimensional coordinates while a geographic index could be built from n-dimensional coordinates) +- built from a single vs multiple coordinate(s) (e.g., `pandas.Index` is built from one coordinate, `pandas.MultiIndex` may be built from an arbitrary number of coordinates and a geographic index would typically require two latitude/longitude coordinates) +- in-memory vs. out-of-core (dask) index data/coordinates (vs. other array backends) +- range-based vs. point-wise selection +- exact vs. inexact lookups + +Whether or not a `XarrayIndex` subclass supports each of the features listed above should be either declared explicitly via a common API or left to the implementation. An `XarrayIndex` subclass may encapsulate more than one underlying object used to perform the actual indexing. Such "meta" index would typically support a range of features among those mentioned above and would automatically select the optimal index object for a given indexing operation. + +An `XarrayIndex` subclass must/should/may implement the following properties/methods: + +- a `from_coords` class method that creates a new index wrapper instance from one or more Dataset/DataArray coordinates (+ some options) +- a `query` method that takes label-based indexers as argument (+ some options) and that returns the corresponding position-based indexers +- an `indexes` property to access the underlying index object(s) wrapped by the `XarrayIndex` subclass +- a `data` property to access index's data and map it to coordinate data (see [Section 4](#4-indexvariable)) +- a `__getitem__()` implementation to propagate the index through DataArray/Dataset indexing operations +- `equals()`, `union()` and `intersection()` methods for data alignment (see [Section 2.6](#26-using-indexes-for-data-alignment)) +- Xarray coordinate getters (see [Section 2.2.4](#224-implicit-coodinates)) +- a method that may return a new index and that will be called when one of the corresponding coordinates is dropped from the Dataset/DataArray (multi-coordinate indexes) +- `encode()`/`decode()` methods that would allow storage-agnostic serialization and fast-path reconstruction of the underlying index object(s) (see [Section 2.8](#28-index-encoding)) +- one or more "non-standard" methods or properties that could be leveraged in Xarray 3rd-party extensions like Dataset/DataArray accessors (see [Section 2.7](#27-using-indexes-for-other-purposes)) + +The `XarrayIndex` API has still to be defined in detail. + +Xarray should provide a minimal set of built-in index wrappers (this could be reduced to the indexes currently supported in Xarray, i.e., `pandas.Index` and `pandas.MultiIndex`). Other index wrappers may be implemented in 3rd-party libraries (recommended). The `XarrayIndex` base class should be part of Xarray's public API. + +#### 2.1.1 Index discoverability + +For better discoverability of Xarray-compatible indexes, Xarray could provide some mechanism to register new index wrappers, e.g., something like [xoak's `IndexRegistry`](https://xoak.readthedocs.io/en/latest/_api_generated/xoak.IndexRegistry.html#xoak.IndexRegistry) or [numcodec's registry](https://numcodecs.readthedocs.io/en/stable/registry.html). + +Additionally (or alternatively), new index wrappers may be registered via entry points as is already the case for storage backends and maybe other backends (plotting) in the future. + +Registering new indexes either via a custom registry or via entry points should be optional. Xarray should also allow providing `XarrayIndex` subclasses in its API (Dataset/DataArray constructors, `set_index()`, etc.). + +### 2.2 Explicit vs. implicit index creation + +#### 2.2.1 Dataset/DataArray's `indexes` constructor argument + +The new `indexes` argument of Dataset/DataArray constructors may be used to specify which kind of index to bind to which coordinate(s). It would consist of a mapping where, for each item, the key is one coordinate name (or a sequence of coordinate names) that must be given in `coords` and the value is the type of the index to build from this (these) coordinate(s): + +```python +>>> da = xr.DataArray( +... data=[[275.2, 273.5], [270.8, 278.6]], +... dims=('x', 'y'), +... coords={ +... 'lat': (('x', 'y'), [[45.6, 46.5], [50.2, 51.6]]), +... 'lon': (('x', 'y'), [[5.7, 10.5], [6.2, 12.8]]), +... }, +... indexes={('lat', 'lon'): SpatialIndex}, +... ) + +array([[275.2, 273.5], + [270.8, 278.6]]) +Coordinates: + * lat (x, y) float64 45.6 46.5 50.2 51.6 + * lon (x, y) float64 5.7 10.5 6.2 12.8 +``` + +More formally, `indexes` would accept `Mapping[CoordinateNames, IndexSpec]` where: + +- `CoordinateNames = Union[CoordinateName, Tuple[CoordinateName, ...]]` and `CoordinateName = Hashable` +- `IndexSpec = Union[Type[XarrayIndex], Tuple[Type[XarrayIndex], Dict[str, Any]], XarrayIndex]`, so that index instances or index classes + build options could be also passed + +Currently index objects like `pandas.MultiIndex` can be passed directly to `coords`, which in this specific case results in the implicit creation of virtual coordinates. With the new `indexes` argument this behavior may become even more confusing than it currently is. For the sake of clarity, it would be appropriate to eventually drop support for this specific behavior and treat any given mapping value given in `coords` as an array that can be wrapped into an Xarray variable, i.e., in the case of a multi-index: + +```python +>>> xr.DataArray([1.0, 2.0], dims='x', coords={'x': midx}) + +array([1., 2.]) +Coordinates: + x (x) object ('a', 0) ('b', 1) +``` + +A possible, more explicit solution to reuse a `pandas.MultiIndex` in a DataArray/Dataset with levels exposed as coordinates is proposed in [Section 2.2.4](#224-implicit-coordinates). + +#### 2.2.2 Dataset/DataArray's `set_index` method + +New indexes may also be built from existing sets of coordinates or variables in a Dataset/DataArray using the `.set_index()` method. + +The [current signature](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.set_index.html#xarray.DataArray.set_index) of `.set_index()` is tailored to `pandas.MultiIndex` and tied to the concept of a dimension-index. It is therefore hardly reusable as-is in the context of flexible indexes proposed here. + +The new signature may look like one of these: + +- A. `.set_index(coords: CoordinateNames, index: Union[XarrayIndex, Type[XarrayIndex]], **index_kwargs)`: one index is set at a time, index construction options may be passed as keyword arguments +- B. `.set_index(indexes: Mapping[CoordinateNames, Union[Type[XarrayIndex], Tuple[Type[XarrayIndex], Dict[str, Any]]]])`: multiple indexes may be set at a time from a mapping of coordinate or variable name(s) as keys and `XarrayIndex` subclasses (maybe with a dict of build options) as values. If variable names are given as keys of they will be promoted as coordinates + +Option A looks simple and elegant but significantly departs from the current signature. Option B is more consistent with the Dataset/DataArray constructor signature proposed in the previous section and would be easier to adopt in parallel with the current signature that we could still support through some depreciation cycle. + +The `append` parameter of the current `.set_index()` is specific to `pandas.MultiIndex`. With option B we could still support it, although we might want to either drop it or move it to the index construction options in the future. + +#### 2.2.3 Implicit default indexes + +In general explicit index creation should be preferred over implicit index creation. However, there is a majority of cases where basic `pandas.Index` objects could be built and used as indexes for 1-dimensional coordinates. For convenience, Xarray should automatically build such indexes for the coordinates where no index has been explicitly assigned in the Dataset/DataArray constructor or when indexes have been reset / dropped. + +For which coordinates? + +- A. only 1D coordinates with a name matching their dimension name +- B. all 1D coordinates + +When to create it? + +- A. each time when a new Dataset/DataArray is created +- B. only when we need it (i.e., when calling `.sel()` or `indexes`) + +Options A and A are what Xarray currently does and may be the best choice considering that indexes could possibly be invalidated by coordinate mutation. + +Besides `pandas.Index`, other indexes currently supported in Xarray like `CFTimeIndex` could be built depending on the coordinate data type. + +#### 2.2.4 Implicit coordinates + +Like for the indexes, explicit coordinate creation should be preferred over implicit coordinate creation. However, there may be some situations where we would like to keep creating coordinates implicitly for backwards compatibility. + +For example, it is currently possible to pass a `pandas.MulitIndex` object as a coordinate to the Dataset/DataArray constructor: + +```python +>>> midx = pd.MultiIndex.from_arrays([['a', 'b'], [0, 1]], names=['lvl1', 'lvl2']) +>>> da = xr.DataArray([1.0, 2.0], dims='x', coords={'x': midx}) +>>> da + +array([1., 2.]) +Coordinates: + * x (x) MultiIndex + - lvl1 (x) object 'a' 'b' + - lvl2 (x) int64 0 1 +``` + +In that case, virtual coordinates are created for each level of the multi-index. After the index refactoring, these coordinates would become real coordinates bound to the multi-index. + +In the example above a coordinate is also created for the `x` dimension: + +```python +>>> da.x + +array([('a', 0), ('b', 1)], dtype=object) +Coordinates: + * x (x) MultiIndex + - lvl1 (x) object 'a' 'b' + - lvl2 (x) int64 0 1 +``` + +With the new proposed data model, this wouldn't be a requirement anymore: there is no concept of a dimension-index. However, some users might still rely on the `x` coordinate so we could still (temporarily) support it for backwards compatibility. + +Besides `pandas.MultiIndex`, there may be other situations where we would like to reuse an existing index in a new Dataset/DataArray (e.g., when the index is very expensive to build), and which might require implicit creation of one or more coordinates. + +The example given here is quite confusing, though: this is not an easily predictable behavior. We could entirely avoid the implicit creation of coordinates, e.g., using a helper function that generates coordinate + index dictionaries that we could then pass directly to the DataArray/Dataset constructor: + +```python +>>> coords_dict, index_dict = create_coords_from_index(midx, dims='x', include_dim_coord=True) +>>> coords_dict +{'x': + array([('a', 0), ('b', 1)], dtype=object), + 'lvl1': + array(['a', 'b'], dtype=object), + 'lvl2': + array([0, 1])} +>>> index_dict +{('lvl1', 'lvl2'): midx} +>>> xr.DataArray([1.0, 2.0], dims='x', coords=coords_dict, indexes=index_dict) + +array([1., 2.]) +Coordinates: + x (x) object ('a', 0) ('b', 1) + * lvl1 (x) object 'a' 'b' + * lvl2 (x) int64 0 1 +``` + +### 2.2.5 Immutable indexes + +Some underlying indexes might be mutable (e.g., a tree-based index structure that allows dynamic addition of data points) while other indexes like `pandas.Index` aren't. To keep things simple, it is probably better to continue considering all indexes in Xarray as immutable (as well as their corresponding coordinates, see [Section 2.4.1](#241-mutable-coordinates)). + +### 2.3 Index access + +#### 2.3.1 Dataset/DataArray's `indexes` property + +The `indexes` property would allow easy access to all the indexes used in a Dataset/DataArray. It would return a `Dict[CoordinateName, XarrayIndex]` for easy index lookup from coordinate name. + +#### 2.3.2 Additional Dataset/DataArray properties or methods + +In some cases the format returned by the `indexes` property would not be the best (e.g, it may return duplicate index instances as values). For convenience, we could add one more property / method to get the indexes in the desired format if needed. + +### 2.4 Propagate indexes through operations + +#### 2.4.1 Mutable coordinates + +Dataset/DataArray coordinates may be replaced (`__setitem__`) or dropped (`__delitem__`) in-place, which may invalidate some of the indexes. A drastic though probably reasonable solution in this case would be to simply drop all indexes bound to those replaced/dropped coordinates. For the case where a 1D basic coordinate that corresponds to a dimension is added/replaced, we could automatically generate a new index (see [Section 2.2.4](#224-implicit-indexes)). + +We must also ensure that coordinates having a bound index are immutable, e.g., still wrap them into `IndexVariable` objects (even though the `IndexVariable` class might change substantially after this refactoring). + +#### 2.4.2 New Dataset/DataArray with updated coordinates + +Xarray provides a variety of Dataset/DataArray operations affecting the coordinates and where simply dropping the index(es) is not desirable. For example: + +- multi-coordinate indexes could be reduced to single coordinate indexes + - like in `.reset_index()` or `.sel()` applied on a subset of the levels of a `pandas.MultiIndex` and that internally call `MultiIndex.droplevel` and `MultiIndex.get_loc_level`, respectively +- indexes may be indexed themselves + - like `pandas.Index` implements `__getitem__()` + - when indexing their corresponding coordinate(s), e.g., via `.sel()` or `.isel()`, those indexes should be indexed too + - this might not be supported by all Xarray indexes, though +- some indexes that can't be indexed could still be automatically (re)built in the new Dataset/DataArray + - like for example building a new `KDTree` index from the selection of a subset of an initial collection of data points + - this is not always desirable, though, as indexes may be expensive to build + - a more reasonable option would be to explicitly re-build the index, e.g., using `.set_index()` +- Dataset/DataArray operations involving alignment (see [Section 2.6](#26-using-indexes-for-data-alignment)) + +### 2.5 Using indexes for data selection + +One main use of indexes is label-based data selection using the DataArray/Dataset `.sel()` method. This refactoring would introduce a number of API changes that could go through some depreciation cycles: + +- the keys of the mapping given to `indexers` (or the names of `indexer_kwargs`) would not correspond to only dimension names but could be the name of any coordinate that has an index +- for a `pandas.MultiIndex`, if no dimension-coordinate is created by default (see [Section 2.2.4](#224-implicit-coordinates)), providing dict-like objects as indexers should be depreciated +- there should be the possibility to provide additional options to the indexes that support specific selection features (e.g., Scikit-learn's `BallTree`'s `dualtree` query option to boost performance). + - the best API is not trivial here, since `.sel()` may accept indexers passed to several indexes (which should still be supported for convenience and compatibility), and indexes may have similar options with different semantics + - we could introduce a new parameter like `index_options: Dict[XarrayIndex, Dict[str, Any]]` to pass options grouped by index +- the `method` and `tolerance` parameters are specific to `pandas.Index` and would not be supported by all indexes: probably best is to eventually pass those arguments as `index_options` +- the list valid indexer types might be extended in order to support new ways of indexing data, e.g., unordered selection of all points within a given range + - alternatively, we could reuse existing indexer types with different semantics depending on the index, e.g., using `slice(min, max, None)` for unordered range selection + +With the new data model proposed here, an ambiguous situation may occur when indexers are given for several coordinates that share the same dimension but not the same index, e.g., from the example in [Section 1](#1-data-model): + +```python +da.sel(x=..., y=..., lat=..., lon=...) +``` + +The easiest solution for this situation would be to raise an error. Alternatively, we could introduce a new parameter to specify how to combine the resulting integer indexers (i.e., union vs intersection), although this could already be achieved by chaining `.sel()` calls or combining `.sel()` with `.merge()` (it may or may not be straightforward). + +### 2.6 Using indexes for data alignment + +Another main use if indexes is data alignment in various operations. Some considerations regarding alignment and flexible indexes: + +- support for alignment should probably be optional for an `XarrayIndex` subclass. + - like `pandas.Index`, the index wrapper classes that support it should implement `.equals()`, `.union()` and/or `.intersection()` + - support might be partial if that makes sense (outer, inner, left, right, exact...). + - index equality might involve more than just the labels: for example a spatial index might be used to check if the coordinate system (CRS) is identical for two sets of coordinates + - some indexes might implement inexact alignment, like in [#4489](https://github.com/pydata/xarray/pull/4489) or a `KDTree` index that selects nearest-neighbors within a given tolerance + - alignment may be "multi-dimensional", i.e., the `KDTree` example above vs. dimensions aligned independently of each other +- we need to decide what to do when one dimension has more than one index that supports alignment + - we should probably raise unless the user explicitly specify which index to use for the alignment +- we need to decide what to do when one dimension has one or more index(es) but none support alignment + - either we raise or we fail back (silently) to alignment based on dimension size +- for inexact alignment, the tolerance threshold might be given when building the index and/or when performing the alignment +- are there cases where we want a specific index to perform alignment and another index to perform selection? + - it would be tricky to support that unless we allow multiple indexes per coordinate + - alternatively, underlying indexes could be picked internally in a "meta" index for one operation or another, although the risk is to eventually have to deal with an explosion of index wrapper classes with different meta indexes for each combination that we'd like to use. + +### 2.7 Using indexes for other purposes + +Xarray also provides a number of Dataset/DataArray methods where indexes are used in various ways, e.g., + +- `resample` (`CFTimeIndex` and a `DatetimeIntervalIndex`) +- `DatetimeAccessor` & `TimedeltaAccessor` properties (`CFTimeIndex` and a `DatetimeIntervalIndex`) +- `interp` & `interpolate_na`, + - with `IntervalIndex`, these become regridding operations. Should we support hooks for these operations? +- `differentiate`, `integrate`, `polyfit` + - raise an error if not a "simple" 1D index? +- `pad` +- `coarsen` has to make choices about output index labels. +- `sortby` +- `stack`/`unstack` +- plotting + - `plot.pcolormesh` "infers" interval breaks along axes, which are really inferred `bounds` for the appropriate indexes. + - `plot.step` again uses `bounds`. In fact, we may even want `step` to be the default 1D plotting function if the axis has `bounds` attached. + +It would be reasonable to first restrict those methods to the indexes that are currently available in Xarray, and maybe extend the `XarrayIndex` API later upon request when the opportunity arises. + +Conversely, nothing should prevent implementing "non-standard" API in 3rd-party `XarrayIndex` subclasses that could be used in DataArray/Dataset extensions (accessors). For example, we might want to reuse a `KDTree` index to compute k-nearest neighbors (returning a DataArray/Dataset with a new dimension) and/or the distances to the nearest neighbors (returning a DataArray/Dataset with a new data variable). + +### 2.8 Index encoding + +Indexes don't need to be directly serializable since we could (re)build them from their corresponding coordinate(s). However, it would be useful if some indexes could be encoded/decoded to/from a set of arrays that would allow optimized reconstruction and/or storage, e.g., + +- `pandas.MultiIndex` -> `index.levels` and `index.codes` +- Scikit-learn's `KDTree` and `BallTree` that use an array-based representation of an immutable tree structure + +## 3. Index representation in DataArray/Dataset's `repr` + +Since indexes would become 1st class citizen of Xarray's data model, they deserve their own section in Dataset/DataArray `repr` that could look like: + +``` + +array([[5.4, 7.8], + [6.2, 4.7]]) +Coordinates: + * lon (x, y) float64 10.2 15.2 12.6 17.6 + * lat (x, y) float64 40.2 45.6 42.2 47.6 + * x (x) float64 200.0 400.0 + * y (y) float64 800.0 1e+03 +Indexes: + lat, lon + x + y +``` + +To keep the `repr` compact, we could: + +- consolidate entries that map to the same index object, and have an short inline repr for `XarrayIndex` object +- collapse the index section by default in the HTML `repr` +- maybe omit all trivial indexes for 1D coordinates that match the dimension name + +## 4. `IndexVariable` + +`IndexVariable` is currently used to wrap a `pandas.Index` as a variable, which would not be relevant after this refactoring since it is aimed at decoupling indexes and variables. + +We'll probably need to move elsewhere some of the features implemented in `IndexVariable` to: + +- ensure that all coordinates with an index are immutable (see [Section 2.4.1](#241-mutable-coordinates)) + - do not set values directly, do not (re)chunk (even though it may be already chunked), do not load, do not convert to sparse/dense, etc. +- directly reuse index's data when that's possible + - in the case of a `pandas.Index`, it makes little sense to have duplicate data (e.g., as a NumPy array) for its corresponding coordinate +- convert a variable into a `pandas.Index` using `.to_index()` (for backwards compatibility). + +Other `IndexVariable` API like `level_names` and `get_level_variable()` would not useful anymore: it is specific to how we currently deal with `pandas.MultiIndex` and virtual "level" coordinates in Xarray. + +## 5. Chunked coordinates and/or indexers + +We could take opportunity of this refactoring to better leverage chunked coordinates (and/or chunked indexers for data selection). There's two ways to enable it: + +A. support for chunked coordinates is left to the index +B. support for chunked coordinates is index agnostic and is implemented in Xarray + +As an example for B, [xoak](https://github.com/ESM-VFC/xoak) supports building an index for each chunk, which is coupled with a two-step data selection process (cross-index queries + brute force "reduction" look-up). There is an example [here](https://xoak.readthedocs.io/en/latest/examples/dask_support.html). This may be tedious to generalize this to other kinds of operations, though. Xoak's Dask support is rather experimental, not super stable (it's quite hard to control index replication and data transfer between Dask workers with the default settings), and depends on whether indexes are thread-safe and/or serializable. + +Option A may be more reasonable for now. + +## 6. Coordinate duck arrays + +Another opportunity of this refactoring is support for duck arrays as index coordinates. Decoupling coordinates and indexes would *de-facto* enable it. + +However, support for duck arrays in index-based operations such as data selection or alignment would probably require some protocol extension, e.g., + +```python +class MyDuckArray: + ... + + def _sel_(self, indexer): + """Prepare the label-based indexer to conform to this coordinate array.""" + ... + return new_indexer + + ... +``` + +For example, a `pint` array would implement `_sel_` to perform indexer unit conversion or raise, warn, or just pass the indexer through if it has no units. diff --git a/doc/Makefile b/doc/Makefile index d88a8a59c39..8b08d3a2dbe 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -19,6 +19,7 @@ I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" + @echo " rtdhtml Build html using same settings used on ReadtheDocs" @echo " livehtml Make standalone HTML files and rebuild the documentation when a change is detected. Also includes a livereload enabled web server" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @@ -58,6 +59,13 @@ html: @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." +.PHONY: rtdhtml +rtdhtml: + $(SPHINXBUILD) -T -j auto -E -W --keep-going -b html -d $(BUILDDIR)/doctrees -D language=en . $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + + .PHONY: livehtml livehtml: # @echo "$(SPHINXATUOBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html" diff --git a/doc/_static/dataset-diagram-logo.png b/doc/_static/dataset-diagram-logo.png index dab01949fa2..23c413d3414 100644 Binary files a/doc/_static/dataset-diagram-logo.png and b/doc/_static/dataset-diagram-logo.png differ diff --git a/doc/_static/style.css b/doc/_static/style.css index b7d30f429cf..833b11a83ab 100644 --- a/doc/_static/style.css +++ b/doc/_static/style.css @@ -1,27 +1,263 @@ -@import url("theme.css"); +table.colwidths-given { + table-layout: fixed; + width: 100%; +} +table.docutils td { + white-space: unset; + word-wrap: break-word; +} + +/* Reduce left and right margins */ + +.container, .container-lg, .container-md, .container-sm, .container-xl { + max-width: 1350px !important; +} + + +/* Copied from +https://github.com/bokeh/bokeh/blob/branch-2.4/sphinx/source/bokeh/static/custom.css +*/ + +:root { + /* Logo image height + all the paddings/margins make the navbar height. */ + --navbar-height: calc(30px + 0.3125rem * 2 + 0.5rem * 2); +} + +.bd-search { + position: relative; + padding-bottom: 20px; +} + +@media (min-width: 768px) { + .search-front-page { + width: 50%; + } +} + +/* minimal copy paste from bootstrap docs css to get sidebars working */ + +.bd-toc { + -ms-flex-order: 2; + order: 2; + padding-top: 1.5rem; + padding-bottom: 1.5rem; + /* font-size: 0.875rem; */ + /* add scrolling sidebar */ + height: calc(100vh - 2rem); + overflow-y: auto; +} + +@supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-toc { + position: -webkit-sticky; + position: sticky; + top: 4rem; + height: calc(100vh - 4rem); + overflow-y: auto; + } +} + +.section-nav { + padding-left: 0; + border-left: 1px solid #eee; + border-bottom: none; +} + +.section-nav ul { + padding-left: 1rem; +} + +.toc-entry { + display: block; +} + +.toc-entry a { + display: block; + padding: 0.125rem 1.5rem; + color: #77757a; +} -.wy-side-nav-search>a img.logo, -.wy-side-nav-search .wy-dropdown>a img.logo { - width: 12rem +.toc-entry a:hover { + color: rgba(0, 0, 0, 0.85); + text-decoration: none; } -.wy-side-nav-search { - background-color: #eee; +.bd-sidebar { + -ms-flex-order: 0; + order: 0; + border-bottom: 1px solid rgba(0, 0, 0, 0.1); } -.wy-side-nav-search>div.version { +@media (min-width: 768px) { + .bd-sidebar { + border-right: 1px solid rgba(0, 0, 0, 0.1); + } + @supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-sidebar { + position: -webkit-sticky; + position: sticky; + top: var(--navbar-height); + z-index: 1000; + height: calc(100vh - var(--navbar-height)); + } + } +} + +@media (min-width: 1200px) { + .bd-sidebar { + -ms-flex: 0 1 320px; + flex: 0 1 320px; + } +} + +.bd-links { + padding-top: 1rem; + padding-bottom: 1rem; + margin-right: -15px; + margin-left: -15px; +} + +@media (min-width: 768px) { + @supports ((position: -webkit-sticky) or (position: sticky)) { + .bd-links { + max-height: calc(100vh - 9rem); + overflow-y: auto; + } + } +} + +@media (min-width: 768px) { + .bd-links { + display: block !important; + } +} + +.bd-sidenav { display: none; } -.wy-nav-top { - background-color: #555; +.bd-toc-link { + display: block; + padding: 0.25rem 1.5rem; + font-weight: 400; + color: rgba(0, 0, 0, 0.65); } -table.colwidths-given { - table-layout: fixed; - width: 100%; +.bd-toc-link:hover { + color: rgba(0, 0, 0, 0.85); + text-decoration: none; } -table.docutils td { - white-space: unset; - word-wrap: break-word; + +.bd-toc-item.active { + margin-bottom: 1rem; +} + +.bd-toc-item.active:not(:first-child) { + margin-top: 1rem; +} + +.bd-toc-item.active > .bd-toc-link { + color: rgba(0, 0, 0, 0.85); +} + +.bd-toc-item.active > .bd-toc-link:hover { + background-color: transparent; +} + +.bd-toc-item.active > .bd-sidenav { + display: block; +} + +.bd-sidebar .nav > li > a { + display: block; + padding: 0.25rem 1.5rem; + font-size: 90%; + color: rgba(0, 0, 0, 0.65); +} + +.bd-sidebar .nav > li > a:hover { + color: rgba(0, 0, 0, 0.85); + text-decoration: none; + background-color: transparent; +} + +.bd-sidebar .nav > .active > a, +.bd-sidebar .nav > .active:hover > a { + font-weight: 400; + color: #130654; + /* adjusted from original + color: rgba(0, 0, 0, 0.85); + background-color: transparent; */ +} + +.bd-sidebar .nav > li > ul { + list-style: none; + padding: 0.25rem 1.5rem; +} + +.bd-sidebar .nav > li > ul > li > a { + display: block; + padding: 0.25rem 1.5rem; + font-size: 90%; + color: rgba(0, 0, 0, 0.65); +} + +.bd-sidebar .nav > li > ul > .active > a, +.bd-sidebar .nav > li > ul > .active:hover > a { + font-weight: 400; + color: #542437; +} + +dt:target { + background-color: initial; +} + +/* Offsetting anchored elements within the main content to adjust for fixed header + https://github.com/pandas-dev/pandas-sphinx-theme/issues/6 */ +main *:target::before { + display: block; + content: ''; + height: var(--navbar-height); + margin-top: calc(-1 * var(--navbar-height)); +} + +body { + /* Add padding to body to avoid overlap with navbar. */ + padding-top: var(--navbar-height); + width: 100%; +} + +/* adjust toc font sizes to improve overview */ +.toc-h2 { + font-size: 0.85rem; +} + +.toc-h3 { + font-size: 0.75rem; +} + +.toc-h4 { + font-size: 0.65rem; +} + +.toc-entry > .nav-link.active { + font-weight: 400; + color: #542437; + background-color: transparent; + border-left: 2px solid #563d7c; +} + +.nav-link:hover { + border-style: none; +} + +/* Collapsing of the TOC sidebar while scrolling */ + +/* Nav: hide second level (shown on .active) */ +.bd-toc .nav .nav { + display: none; +} + +.bd-toc .nav > .active > ul { + display: block; } diff --git a/doc/_static/thumbnails/ERA5-GRIB-example.png b/doc/_static/thumbnails/ERA5-GRIB-example.png new file mode 100644 index 00000000000..412dd28a6d9 Binary files /dev/null and b/doc/_static/thumbnails/ERA5-GRIB-example.png differ diff --git a/doc/_static/thumbnails/ROMS_ocean_model.png b/doc/_static/thumbnails/ROMS_ocean_model.png new file mode 100644 index 00000000000..9333335d1ef Binary files /dev/null and b/doc/_static/thumbnails/ROMS_ocean_model.png differ diff --git a/doc/_static/thumbnails/area_weighted_temperature.png b/doc/_static/thumbnails/area_weighted_temperature.png new file mode 100644 index 00000000000..7d3604d7c2b Binary files /dev/null and b/doc/_static/thumbnails/area_weighted_temperature.png differ diff --git a/doc/_static/thumbnails/monthly-means.png b/doc/_static/thumbnails/monthly-means.png new file mode 100644 index 00000000000..da5691848b0 Binary files /dev/null and b/doc/_static/thumbnails/monthly-means.png differ diff --git a/doc/_static/thumbnails/multidimensional-coords.png b/doc/_static/thumbnails/multidimensional-coords.png new file mode 100644 index 00000000000..b0d893d6894 Binary files /dev/null and b/doc/_static/thumbnails/multidimensional-coords.png differ diff --git a/doc/_static/thumbnails/toy-weather-data.png b/doc/_static/thumbnails/toy-weather-data.png new file mode 100644 index 00000000000..64ac0a4b021 Binary files /dev/null and b/doc/_static/thumbnails/toy-weather-data.png differ diff --git a/doc/_static/thumbnails/visualization_gallery.png b/doc/_static/thumbnails/visualization_gallery.png new file mode 100644 index 00000000000..9e6c2436be5 Binary files /dev/null and b/doc/_static/thumbnails/visualization_gallery.png differ diff --git a/doc/_templates/autosummary/base.rst b/doc/_templates/autosummary/base.rst deleted file mode 100644 index 53f2a29c193..00000000000 --- a/doc/_templates/autosummary/base.rst +++ /dev/null @@ -1,3 +0,0 @@ -:github_url: {{ fullname | github_url | escape_underscores }} - -{% extends "!autosummary/base.rst" %} diff --git a/doc/_templates/layout.html b/doc/_templates/layout.html deleted file mode 100644 index 4c57ba83056..00000000000 --- a/doc/_templates/layout.html +++ /dev/null @@ -1,2 +0,0 @@ -{% extends "!layout.html" %} -{% set css_files = css_files + ["_static/style.css"] %} diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index e5492ec73a4..fc27d9c3fe8 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -41,18 +41,19 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any + core.rolling.DatasetCoarsen.construct core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean core.rolling.DatasetCoarsen.median core.rolling.DatasetCoarsen.min core.rolling.DatasetCoarsen.prod + core.rolling.DatasetCoarsen.reduce core.rolling.DatasetCoarsen.std core.rolling.DatasetCoarsen.sum core.rolling.DatasetCoarsen.var core.rolling.DatasetCoarsen.boundary core.rolling.DatasetCoarsen.coord_func - core.rolling.DatasetCoarsen.keep_attrs core.rolling.DatasetCoarsen.obj core.rolling.DatasetCoarsen.side core.rolling.DatasetCoarsen.trim_excess @@ -118,7 +119,6 @@ core.rolling.DatasetRolling.var core.rolling.DatasetRolling.center core.rolling.DatasetRolling.dim - core.rolling.DatasetRolling.keep_attrs core.rolling.DatasetRolling.min_periods core.rolling.DatasetRolling.obj core.rolling.DatasetRolling.rollings @@ -184,18 +184,19 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any + core.rolling.DataArrayCoarsen.construct core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean core.rolling.DataArrayCoarsen.median core.rolling.DataArrayCoarsen.min core.rolling.DataArrayCoarsen.prod + core.rolling.DataArrayCoarsen.reduce core.rolling.DataArrayCoarsen.std core.rolling.DataArrayCoarsen.sum core.rolling.DataArrayCoarsen.var core.rolling.DataArrayCoarsen.boundary core.rolling.DataArrayCoarsen.coord_func - core.rolling.DataArrayCoarsen.keep_attrs core.rolling.DataArrayCoarsen.obj core.rolling.DataArrayCoarsen.side core.rolling.DataArrayCoarsen.trim_excess @@ -259,7 +260,6 @@ core.rolling.DataArrayRolling.var core.rolling.DataArrayRolling.center core.rolling.DataArrayRolling.dim - core.rolling.DataArrayRolling.keep_attrs core.rolling.DataArrayRolling.min_periods core.rolling.DataArrayRolling.obj core.rolling.DataArrayRolling.window @@ -285,6 +285,7 @@ core.accessor_dt.DatetimeAccessor.floor core.accessor_dt.DatetimeAccessor.round core.accessor_dt.DatetimeAccessor.strftime + core.accessor_dt.DatetimeAccessor.date core.accessor_dt.DatetimeAccessor.day core.accessor_dt.DatetimeAccessor.dayofweek core.accessor_dt.DatetimeAccessor.dayofyear @@ -322,14 +323,21 @@ core.accessor_dt.TimedeltaAccessor.seconds core.accessor_str.StringAccessor.capitalize + core.accessor_str.StringAccessor.casefold + core.accessor_str.StringAccessor.cat core.accessor_str.StringAccessor.center core.accessor_str.StringAccessor.contains core.accessor_str.StringAccessor.count core.accessor_str.StringAccessor.decode core.accessor_str.StringAccessor.encode core.accessor_str.StringAccessor.endswith + core.accessor_str.StringAccessor.extract + core.accessor_str.StringAccessor.extractall core.accessor_str.StringAccessor.find + core.accessor_str.StringAccessor.findall + core.accessor_str.StringAccessor.format core.accessor_str.StringAccessor.get + core.accessor_str.StringAccessor.get_dummies core.accessor_str.StringAccessor.index core.accessor_str.StringAccessor.isalnum core.accessor_str.StringAccessor.isalpha @@ -340,20 +348,26 @@ core.accessor_str.StringAccessor.isspace core.accessor_str.StringAccessor.istitle core.accessor_str.StringAccessor.isupper + core.accessor_str.StringAccessor.join core.accessor_str.StringAccessor.len core.accessor_str.StringAccessor.ljust core.accessor_str.StringAccessor.lower core.accessor_str.StringAccessor.lstrip core.accessor_str.StringAccessor.match + core.accessor_str.StringAccessor.normalize core.accessor_str.StringAccessor.pad + core.accessor_str.StringAccessor.partition core.accessor_str.StringAccessor.repeat core.accessor_str.StringAccessor.replace core.accessor_str.StringAccessor.rfind core.accessor_str.StringAccessor.rindex core.accessor_str.StringAccessor.rjust + core.accessor_str.StringAccessor.rpartition + core.accessor_str.StringAccessor.rsplit core.accessor_str.StringAccessor.rstrip core.accessor_str.StringAccessor.slice core.accessor_str.StringAccessor.slice_replace + core.accessor_str.StringAccessor.split core.accessor_str.StringAccessor.startswith core.accessor_str.StringAccessor.strip core.accessor_str.StringAccessor.swapcase @@ -581,6 +595,7 @@ plot.imshow plot.pcolormesh plot.scatter + plot.surface plot.FacetGrid.map_dataarray plot.FacetGrid.set_titles @@ -809,3 +824,27 @@ backends.DummyFileManager.acquire backends.DummyFileManager.acquire_context backends.DummyFileManager.close + + backends.BackendArray + backends.BackendEntrypoint.guess_can_open + backends.BackendEntrypoint.open_dataset + + core.indexing.IndexingSupport + core.indexing.explicit_indexing_adapter + core.indexing.BasicIndexer + core.indexing.OuterIndexer + core.indexing.VectorizedIndexer + core.indexing.LazilyIndexedArray + core.indexing.LazilyVectorizedIndexedArray + + conventions.decode_cf_variables + + coding.variables.UnsignedIntegerCoder + coding.variables.CFMaskCoder + coding.variables.CFScaleOffsetCoder + + coding.strings.CharacterArrayCoder + coding.strings.EncodedStringCoder + + coding.times.CFTimedeltaCoder + coding.times.CFDatetimeCoder diff --git a/doc/api.rst b/doc/api.rst index 0bce923f9fc..7bf745ef686 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,7 +24,6 @@ Top-level functions combine_by_coords combine_nested where - set_options infer_freq full_like zeros_like @@ -36,6 +35,7 @@ Top-level functions map_blocks show_versions set_options + unify_chunks Dataset ======= @@ -138,6 +138,7 @@ Indexing Dataset.set_index Dataset.reset_index Dataset.reorder_levels + Dataset.query Missing value handling ---------------------- @@ -178,6 +179,7 @@ Computation Dataset.integrate Dataset.map_blocks Dataset.polyfit + Dataset.curvefit **Aggregation**: :py:attr:`~Dataset.all` @@ -241,6 +243,8 @@ Plotting :template: autosummary/accessor_method.rst Dataset.plot.scatter + Dataset.plot.quiver + Dataset.plot.streamplot DataArray ========= @@ -288,6 +292,7 @@ DataArray contents DataArray.swap_dims DataArray.expand_dims DataArray.drop_vars + DataArray.drop_duplicates DataArray.reset_coords DataArray.copy @@ -320,6 +325,7 @@ Indexing DataArray.set_index DataArray.reset_index DataArray.reorder_levels + DataArray.query Missing value handling ---------------------- @@ -371,7 +377,7 @@ Computation DataArray.integrate DataArray.polyfit DataArray.map_blocks - + DataArray.curvefit **Aggregation**: :py:attr:`~DataArray.all` @@ -415,42 +421,62 @@ Computation String manipulation ------------------- +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor.rst + + DataArray.str + .. autosummary:: :toctree: generated/ :template: autosummary/accessor_method.rst DataArray.str.capitalize + DataArray.str.casefold + DataArray.str.cat DataArray.str.center DataArray.str.contains DataArray.str.count DataArray.str.decode DataArray.str.encode DataArray.str.endswith + DataArray.str.extract + DataArray.str.extractall DataArray.str.find + DataArray.str.findall + DataArray.str.format DataArray.str.get + DataArray.str.get_dummies DataArray.str.index DataArray.str.isalnum DataArray.str.isalpha DataArray.str.isdecimal DataArray.str.isdigit + DataArray.str.islower DataArray.str.isnumeric DataArray.str.isspace DataArray.str.istitle DataArray.str.isupper + DataArray.str.join DataArray.str.len DataArray.str.ljust DataArray.str.lower DataArray.str.lstrip DataArray.str.match + DataArray.str.normalize DataArray.str.pad + DataArray.str.partition DataArray.str.repeat DataArray.str.replace DataArray.str.rfind DataArray.str.rindex DataArray.str.rjust + DataArray.str.rpartition + DataArray.str.rsplit DataArray.str.rstrip DataArray.str.slice DataArray.str.slice_replace + DataArray.str.split DataArray.str.startswith DataArray.str.strip DataArray.str.swapcase @@ -486,6 +512,7 @@ Datetimelike properties DataArray.dt.daysinmonth DataArray.dt.season DataArray.dt.time + DataArray.dt.date DataArray.dt.is_month_start DataArray.dt.is_month_end DataArray.dt.is_quarter_end @@ -562,6 +589,7 @@ Plotting DataArray.plot.line DataArray.plot.pcolormesh DataArray.plot.step + DataArray.plot.surface .. _api.ufuncs: @@ -656,6 +684,8 @@ Dataset methods open_rasterio open_zarr Dataset.to_netcdf + Dataset.to_pandas + Dataset.as_numpy Dataset.to_zarr save_mfdataset Dataset.to_array @@ -686,6 +716,8 @@ DataArray methods DataArray.to_pandas DataArray.to_series DataArray.to_dataframe + DataArray.to_numpy + DataArray.as_numpy DataArray.to_index DataArray.to_masked_array DataArray.to_cdms2 @@ -836,6 +868,7 @@ Faceting plot.FacetGrid plot.FacetGrid.add_colorbar plot.FacetGrid.add_legend + plot.FacetGrid.add_quiverkey plot.FacetGrid.map plot.FacetGrid.map_dataarray plot.FacetGrid.map_dataarray_line @@ -853,6 +886,7 @@ Tutorial :toctree: generated/ tutorial.open_dataset + tutorial.open_rasterio tutorial.load_dataset Testing @@ -886,8 +920,12 @@ Advanced API Variable IndexVariable as_variable + Context register_dataset_accessor register_dataarray_accessor + Dataset.set_close + backends.BackendArray + backends.BackendEntrypoint These backends provide a low-level interface for lazily loading data from external file-formats or protocols, and can be manually invoked to create diff --git a/doc/conf.py b/doc/conf.py index 14b28b4e471..0a6d1504161 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,14 +13,13 @@ import datetime +import inspect import os -import pathlib import subprocess import sys from contextlib import suppress import sphinx_autosummary_accessors -from jinja2.defaults import DEFAULT_FILTERS import xarray @@ -34,7 +33,7 @@ subprocess.run(["conda", "list"]) else: print("pip environment:") - subprocess.run(["pip", "list"]) + subprocess.run([sys.executable, "-m", "pip", "list"]) print(f"xarray: {xarray.__version__}, {xarray.__file__}") @@ -80,7 +79,11 @@ "IPython.sphinxext.ipython_console_highlighting", "nbsphinx", "sphinx_autosummary_accessors", - "scanpydoc.rtd_github_links", + "sphinx.ext.linkcode", + "sphinx_panels", + "sphinxext.opengraph", + "sphinx_copybutton", + "sphinxext.rediraffe", ] extlinks = { @@ -88,32 +91,30 @@ "pull": ("https://github.com/pydata/xarray/pull/%s", "PR"), } +# sphinx-copybutton configurations +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +copybutton_prompt_is_regexp = True + +# nbsphinx configurations + nbsphinx_timeout = 600 nbsphinx_execute = "always" nbsphinx_prolog = """ {% set docname = env.doc2path(env.docname, base=None) %} -You can run this notebook in a `live session `_ |Binder| or view it `on Github `_. +You can run this notebook in a `live session `_ |Binder| or view it `on Github `_. .. |Binder| image:: https://mybinder.org/badge.svg - :target: https://mybinder.org/v2/gh/pydata/xarray/master?urlpath=lab/tree/doc/{{ docname }} + :target: https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/{{ docname }} """ autosummary_generate = True - -# for scanpydoc's jinja filter -project_dir = pathlib.Path(__file__).parent.parent -html_context = { - "github_user": "pydata", - "github_repo": "xarray", - "github_version": "master", -} - autodoc_typehints = "none" +# Napoleon configurations + napoleon_google_docstring = False napoleon_numpy_docstring = True - napoleon_use_param = False napoleon_use_rtype = False napoleon_preprocess_types = True @@ -124,6 +125,7 @@ "callable": ":py:func:`callable`", "dict_like": ":term:`dict-like `", "dict-like": ":term:`dict-like `", + "path-like": ":term:`path-like `", "mapping": ":term:`mapping`", "file-like": ":term:`file-like `", # special terms @@ -142,7 +144,7 @@ "hashable": ":term:`hashable `", # matplotlib terms "color-like": ":py:func:`color-like `", - "matplotlib colormap name": ":doc:matplotlib colormap name ", + "matplotlib colormap name": ":doc:`matplotlib colormap name `", "matplotlib axes object": ":py:class:`matplotlib axes object `", "colormap": ":py:class:`colormap `", # objects without namespace @@ -167,17 +169,13 @@ "pd.NaT": "~pandas.NaT", } -numpydoc_class_members_toctree = True -numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] # The suffix of source filenames. -source_suffix = ".rst" +# source_suffix = ".rst" -# The encoding of source files. -# source_encoding = 'utf-8-sig' # The master toctree document. master_doc = "index" @@ -186,19 +184,11 @@ project = "xarray" copyright = "2014-%s, xarray Developers" % datetime.datetime.now().year -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# # The short X.Y version. version = xarray.__version__.split("+")[0] # The full version, including alpha/beta/rc tags. release = xarray.__version__ -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# language = None - # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: # today = '' @@ -209,51 +199,45 @@ # directories to ignore when looking for source files. exclude_patterns = ["_build", "**.ipynb_checkpoints"] -# The reST default role (used for this markup: `text`) to use for all -# documents. -# default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -# add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -# add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -# show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" -# A list of ignored prefixes for module index sorting. -# modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. -# keep_warnings = False - # -- Options for HTML output ---------------------------------------------- - # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = "sphinx_rtd_theme" +html_theme = "sphinx_book_theme" +html_title = "" + +html_context = { + "github_user": "pydata", + "github_repo": "xarray", + "github_version": "main", + "doc_path": "doc", +} # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -html_theme_options = {"logo_only": True} - -# Add any paths that contain custom themes here, relative to this directory. -# html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -# html_title = None +html_theme_options = dict( + # analytics_id='' this is configured in rtfd.io + # canonical_url="", + repository_url="https://github.com/pydata/xarray", + repository_branch="main", + path_to_docs="doc", + use_edit_page_button=True, + use_repository_button=True, + use_issues_button=True, + home_page_in_toc=False, + extra_navbar="", + navbar_footer_text="", + extra_footer="""

Xarray is a fiscally sponsored project of NumFOCUS, + a nonprofit dedicated to supporting the open-source scientific computing community.
+ Theme by the Executable Book Project

""", + twitter_url="https://twitter.com/xarray_devs", +) -# A shorter title for the navigation bar. Default is the same as html_title. -# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. @@ -268,6 +252,42 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] +html_css_files = ["style.css"] + + +# configuration for sphinxext.opengraph +ogp_site_url = "https://xarray.pydata.org/en/latest/" +ogp_image = "https://xarray.pydata.org/en/stable/_static/dataset-diagram-logo.png" +ogp_custom_meta_tags = [ + '', + '', +] + +# Redirects for pages that were moved to new locations + +rediraffe_redirects = { + "terminology.rst": "user-guide/terminology.rst", + "data-structures.rst": "user-guide/data-structures.rst", + "indexing.rst": "user-guide/indexing.rst", + "interpolation.rst": "user-guide/interpolation.rst", + "computation.rst": "user-guide/computation.rst", + "groupby.rst": "user-guide/groupby.rst", + "reshaping.rst": "user-guide/reshaping.rst", + "combining.rst": "user-guide/combining.rst", + "time-series.rst": "user-guide/time-series.rst", + "weather-climate.rst": "user-guide/weather-climate.rst", + "pandas.rst": "user-guide/pandas.rst", + "io.rst": "user-guide/io.rst", + "dask.rst": "user-guide/dask.rst", + "plotting.rst": "user-guide/plotting.rst", + "duckarrays.rst": "user-guide/duckarrays.rst", + "related-projects.rst": "ecosystem.rst", + "faq.rst": "getting-started-guide/faq.rst", + "why-xarray.rst": "getting-started-guide/why-xarray.rst", + "installing.rst": "getting-started-guide/installing.rst", + "quick-overview.rst": "getting-started-guide/quick-overview.rst", +} # Sometimes the savefig directory doesn't exist and needs to be created # https://github.com/ipython/ipython/issues/8733 @@ -278,144 +298,24 @@ if not os.path.exists(ipython_savefig_dir): os.makedirs(ipython_savefig_dir) -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. html_last_updated_fmt = today_fmt -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -# html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -# html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -# html_additional_pages = {} - -# If false, no module index is generated. -# html_domain_indices = True - -# If false, no index is generated. -# html_use_index = True - -# If true, the index is split into individual pages for each letter. -# html_split_index = False - -# If true, links to the reST sources are added to the pages. -# html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -# html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -# html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -# html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -# html_file_suffix = None - # Output file base name for HTML help builder. htmlhelp_basename = "xarraydoc" -# -- Options for LaTeX output --------------------------------------------- - -# latex_elements = { -# # The paper size ('letterpaper' or 'a4paper'). -# # 'papersize': 'letterpaper', -# # The font size ('10pt', '11pt' or '12pt'). -# # 'pointsize': '10pt', -# # Additional stuff for the LaTeX preamble. -# # 'preamble': '', -# } - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -# latex_documents = [ -# ("index", "xarray.tex", "xarray Documentation", "xarray Developers", "manual") -# ] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -# latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -# latex_use_parts = False - -# If true, show page references after internal links. -# latex_show_pagerefs = False - -# If true, show URL addresses after external links. -# latex_show_urls = False - -# Documents to append as an appendix to all manuals. -# latex_appendices = [] - -# If false, no module index is generated. -# latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -# man_pages = [("index", "xarray", "xarray Documentation", ["xarray Developers"], 1)] - -# If true, show URL addresses after external links. -# man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -# texinfo_documents = [ -# ( -# "index", -# "xarray", -# "xarray Documentation", -# "xarray Developers", -# "xarray", -# "N-D labeled arrays and datasets in Python.", -# "Miscellaneous", -# ) -# ] - -# Documents to append as an appendix to all manuals. -# texinfo_appendices = [] - -# If false, no module index is generated. -# texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -# texinfo_no_detailmenu = False - - # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), - "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), "numba": ("https://numba.pydata.org/numba-doc/latest", None), - "matplotlib": ("https://matplotlib.org", None), + "matplotlib": ("https://matplotlib.org/stable/", None), "dask": ("https://docs.dask.org/en/latest", None), "cftime": ("https://unidata.github.io/cftime", None), "rasterio": ("https://rasterio.readthedocs.io/en/latest", None), @@ -423,9 +323,61 @@ } -def escape_underscores(string): - return string.replace("_", r"\_") +# based on numpy doc/source/conf.py +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except AttributeError: + return None + + try: + fn = inspect.getsourcefile(inspect.unwrap(obj)) + except TypeError: + fn = None + if not fn: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except OSError: + lineno = None + + if lineno: + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" + else: + linespec = "" + + fn = os.path.relpath(fn, start=os.path.dirname(xarray.__file__)) + + if "+" in xarray.__version__: + return f"https://github.com/pydata/xarray/blob/main/xarray/{fn}{linespec}" + else: + return ( + f"https://github.com/pydata/xarray/blob/" + f"v{xarray.__version__}/xarray/{fn}{linespec}" + ) + + +def html_page_context(app, pagename, templatename, context, doctree): + # Disable edit button for docstring generated pages + if "generated" in pagename: + context["theme_use_edit_page_button"] = False def setup(app): - DEFAULT_FILTERS["escape_underscores"] = escape_underscores + app.connect("html-page-context", html_page_context) diff --git a/doc/contributing.rst b/doc/contributing.rst index 439791cbbd6..d73d18d5df7 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -4,8 +4,6 @@ Contributing to xarray ********************** -.. contents:: Table of contents: - :local: .. note:: @@ -40,7 +38,7 @@ report will allow others to reproduce the bug and provide insight into fixing. S `this stackoverflow article `_ for tips on writing a good bug report. -Trying out the bug-producing code on the *master* branch is often a worthwhile exercise +Trying out the bug-producing code on the *main* branch is often a worthwhile exercise to confirm that the bug still exists. It is also worth searching existing bug reports and pull requests to see if the issue has already been reported and/or fixed. @@ -52,7 +50,7 @@ Bug reports must: ```python import xarray as xr - df = xr.Dataset(...) + ds = xr.Dataset(...) ... ``` @@ -60,8 +58,12 @@ Bug reports must: #. Include the full version string of *xarray* and its dependencies. You can use the built in function:: - >>> import xarray as xr - >>> xr.show_versions() + ```python + import xarray as xr + xr.show_versions() + + ... + ``` #. Explain why the current behavior is wrong/not desired and what you expect instead. @@ -95,7 +97,7 @@ Some great resources for learning Git: * the `GitHub help pages `_. * the `NumPy's documentation `_. -* Matthew Brett's `Pydagogue `_. +* Matthew Brett's `Pydagogue `_. Getting started with Git ------------------------ @@ -194,7 +196,7 @@ See the full conda docs `here `__. Creating a branch ----------------- -You want your master branch to reflect only production-ready code, so create a +You want your ``main`` branch to reflect only production-ready code, so create a feature branch before making your changes. For example:: git branch shiny-new-feature @@ -209,12 +211,12 @@ changes in this branch specific to one bug or feature so it is clear what the branch brings to *xarray*. You can have many "shiny-new-features" and switch in between them using the ``git checkout`` command. -To update this branch, you need to retrieve the changes from the master branch:: +To update this branch, you need to retrieve the changes from the ``main`` branch:: git fetch upstream - git merge upstream/master + git merge upstream/main -This will combine your commits with the latest *xarray* git master. If this +This will combine your commits with the latest *xarray* git ``main``. If this leads to merge conflicts, you must resolve these before submitting your pull request. If you have uncommitted changes, you will need to ``git stash`` them prior to updating. This will effectively store your changes, which can be @@ -381,10 +383,34 @@ with ``git commit --no-verify``. Backwards Compatibility ~~~~~~~~~~~~~~~~~~~~~~~ -Please try to maintain backward compatibility. *xarray* has growing number of users with +Please try to maintain backwards compatibility. *xarray* has a growing number of users with lots of existing code, so don't break it if at all possible. If you think breakage is -required, clearly state why as part of the pull request. Also, be careful when changing -method signatures and add deprecation warnings where needed. +required, clearly state why as part of the pull request. + +Be especially careful when changing function and method signatures, because any change +may require a deprecation warning. For example, if your pull request means that the +argument ``old_arg`` to ``func`` is no longer valid, instead of simply raising an error if +a user passes ``old_arg``, we would instead catch it: + +.. code-block:: python + + def func(new_arg, old_arg=None): + if old_arg is not None: + from warnings import warn + + warn( + "`old_arg` has been deprecated, and in the future will raise an error." + "Please use `new_arg` from now on.", + DeprecationWarning, + ) + + # Still do what the user intended here + +This temporary check would then be removed in a subsequent version of xarray. +This process of first warning users before actually breaking their code is known as a +"deprecation cycle", and makes changes significantly easier to handle both for users +of xarray, and for developers of other libraries that depend on xarray. + .. _contributing.ci: @@ -637,14 +663,14 @@ To install asv:: If you need to run a benchmark, change your directory to ``asv_bench/`` and run:: - asv continuous -f 1.1 upstream/master HEAD + asv continuous -f 1.1 upstream/main HEAD You can replace ``HEAD`` with the name of the branch you are working on, and report benchmarks that changed by more than 10%. The command uses ``conda`` by default for creating the benchmark environments. If you want to use virtualenv instead, write:: - asv continuous -f 1.1 -E virtualenv upstream/master HEAD + asv continuous -f 1.1 -E virtualenv upstream/main HEAD The ``-E virtualenv`` option should be added to all ``asv`` commands that run benchmarks. The default value is defined in ``asv.conf.json``. @@ -656,12 +682,12 @@ regressions. You can run specific benchmarks using the ``-b`` flag, which takes a regular expression. For example, this will only run tests from a ``xarray/asv_bench/benchmarks/groupby.py`` file:: - asv continuous -f 1.1 upstream/master HEAD -b ^groupby + asv continuous -f 1.1 upstream/main HEAD -b ^groupby If you want to only run a specific group of tests from a file, you can do it using ``.`` as a separator. For example:: - asv continuous -f 1.1 upstream/master HEAD -b groupby.GroupByMethods + asv continuous -f 1.1 upstream/main HEAD -b groupby.GroupByMethods will only run the ``GroupByMethods`` benchmark defined in ``groupby.py``. @@ -686,8 +712,12 @@ This will display stderr from the benchmarks, and use your local Information on how to write a benchmark and how to use asv can be found in the `asv documentation `_. -The *xarray* benchmarking suite is run remotely and the results are -available `here `_. +.. + TODO: uncomment once we have a working setup + see https://github.com/pydata/xarray/pull/5066 + + The *xarray* benchmarking suite is run remotely and the results are + available `here `_. Documenting your code --------------------- @@ -773,7 +803,7 @@ double check your branch changes against the branch it was based on: #. Navigate to your repository on GitHub -- https://github.com/your-user-name/xarray #. Click on ``Branches`` #. Click on the ``Compare`` button for your feature branch -#. Select the ``base`` and ``compare`` branches, if necessary. This will be ``master`` and +#. Select the ``base`` and ``compare`` branches, if necessary. This will be ``main`` and ``shiny-new-feature``, respectively. Finally, make the pull request @@ -781,8 +811,8 @@ Finally, make the pull request If everything looks good, you are ready to make a pull request. A pull request is how code from a local repository becomes available to the GitHub community and can be looked -at and eventually merged into the master version. This pull request and its associated -changes will eventually be committed to the master branch and available in the next +at and eventually merged into the ``main`` version. This pull request and its associated +changes will eventually be committed to the ``main`` branch and available in the next release. To submit a pull request: #. Navigate to your repository on GitHub @@ -807,11 +837,11 @@ Delete your merged branch (optional) ------------------------------------ Once your feature branch is accepted into upstream, you'll probably want to get rid of -the branch. First, update your ``master`` branch to check that the merge was successful:: +the branch. First, update your ``main`` branch to check that the merge was successful:: git fetch upstream - git checkout master - git merge upstream/master + git checkout main + git merge upstream/main Then you can do:: diff --git a/doc/related-projects.rst b/doc/ecosystem.rst similarity index 99% rename from doc/related-projects.rst rename to doc/ecosystem.rst index 0a010195d6d..01f5c29b9f5 100644 --- a/doc/related-projects.rst +++ b/doc/ecosystem.rst @@ -1,4 +1,4 @@ -.. _related-projects: +.. _ecosystem: Xarray related projects ----------------------- diff --git a/doc/examples.rst b/doc/examples.rst deleted file mode 100644 index 102138b6e4e..00000000000 --- a/doc/examples.rst +++ /dev/null @@ -1,29 +0,0 @@ -Examples -======== - -.. toctree:: - :maxdepth: 1 - - examples/weather-data - examples/monthly-means - examples/area_weighted_temperature - examples/multidimensional-coords - examples/visualization_gallery - examples/ROMS_ocean_model - examples/ERA5-GRIB-example - -Using apply_ufunc ------------------- -.. toctree:: - :maxdepth: 1 - - examples/apply_ufunc_vectorize_1d - -External Examples ------------------ -.. toctree:: - :maxdepth: 2 - - Managing raster data with rioxarray - Xarray with dask - Xarray and dask on the cloud with Pangeo diff --git a/doc/examples/ERA5-GRIB-example.ipynb b/doc/examples/ERA5-GRIB-example.ipynb index b82a07a64e6..1c6be5f6634 100644 --- a/doc/examples/ERA5-GRIB-example.ipynb +++ b/doc/examples/ERA5-GRIB-example.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "GRIB format is commonly used to disemminate atmospheric model data. With Xarray and the cfgrib engine, GRIB data can easily be analyzed and visualized." + "GRIB format is commonly used to disseminate atmospheric model data. With Xarray and the cfgrib engine, GRIB data can easily be analyzed and visualized." ] }, { diff --git a/doc/examples/ROMS_ocean_model.ipynb b/doc/examples/ROMS_ocean_model.ipynb index 74536bbe28f..b699c4d5ba9 100644 --- a/doc/examples/ROMS_ocean_model.ipynb +++ b/doc/examples/ROMS_ocean_model.ipynb @@ -120,7 +120,7 @@ "source": [ "### A naive vertical slice\n", "\n", - "Create a slice using the s-coordinate as the vertical dimension is typically not very informative." + "Creating a slice using the s-coordinate as the vertical dimension is typically not very informative." ] }, { diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb index a79a4868b63..e9a48d70173 100644 --- a/doc/examples/apply_ufunc_vectorize_1d.ipynb +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -494,7 +494,7 @@ "source": [ "So far our function can only handle numpy arrays. A real benefit of `apply_ufunc` is the ability to easily parallelize over dask chunks _when needed_. \n", "\n", - "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise` or `map_blocks`. `apply_ufunc` wraps `blockwise` and asking it to map the function over chunks using `blockwise` is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility we need to provide dask with some extra information: \n", + "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise`, `map_blocks`, or `apply_gufunc`. Xarray's `apply_ufunc` wraps dask's `apply_gufunc` and asking it to map the function over chunks using `apply_gufunc` is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility we need to provide dask with some extra information: \n", " 1. `output_dtypes`: dtypes of all returned objects, and \n", " 2. `output_sizes`: lengths of any new dimensions. \n", " \n", @@ -711,7 +711,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.10" }, "nbsphinx": { "allow_errors": true @@ -732,5 +732,5 @@ } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb index de705966583..7299b50b1b3 100644 --- a/doc/examples/area_weighted_temperature.ipynb +++ b/doc/examples/area_weighted_temperature.ipynb @@ -20,7 +20,7 @@ "Author: [Mathias Hauser](https://github.com/mathause/)\n", "\n", "\n", - "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the gridcell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" + "We use the `air_temperature` example dataset to calculate the area-weighted temperature over its domain. This dataset has a regular latitude/ longitude grid, thus the grid cell area decreases towards the pole. For this grid we can use the cosine of the latitude as proxy for the grid cell area.\n" ] }, { diff --git a/doc/examples/monthly-means.ipynb b/doc/examples/monthly-means.ipynb index bc88f4a9fc9..3490fc9a4fe 100644 --- a/doc/examples/monthly-means.ipynb +++ b/doc/examples/monthly-means.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Calculating Seasonal Averages from Timeseries of Monthly Means \n", + "Calculating Seasonal Averages from Time Series of Monthly Means \n", "=====\n", "\n", "Author: [Joe Hamman](https://github.com/jhamman/)\n", @@ -60,10 +60,10 @@ "source": [ "#### Now for the heavy lifting:\n", "We first have to come up with the weights,\n", - "- calculate the month lengths for each monthly data record\n", + "- calculate the month length for each monthly data record\n", "- calculate weights using `groupby('time.season')`\n", "\n", - "Finally, we just need to multiply our weights by the `Dataset` and sum allong the time dimension. Creating a `DataArray` for the month length is as easy as using the `days_in_month` accessor on the time coordinate. The calendar type, in this case `'noleap'`, is automatically considered in this operation." + "Finally, we just need to multiply our weights by the `Dataset` and sum along the time dimension. Creating a `DataArray` for the month length is as easy as using the `days_in_month` accessor on the time coordinate. The calendar type, in this case `'noleap'`, is automatically considered in this operation." ] }, { diff --git a/doc/examples/visualization_gallery.ipynb b/doc/examples/visualization_gallery.ipynb index f8d5b1ae458..3f2973dbdb4 100644 --- a/doc/examples/visualization_gallery.ipynb +++ b/doc/examples/visualization_gallery.ipynb @@ -209,8 +209,7 @@ "metadata": {}, "outputs": [], "source": [ - "url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n", - "da = xr.open_rasterio(url)\n", + "da = xr.tutorial.open_rasterio(\"RGB.byte\")\n", "\n", "# The data is in UTM projection. We have to set it manually until\n", "# https://github.com/SciTools/cartopy/issues/813 is implemented\n", @@ -246,8 +245,7 @@ "from rasterio.warp import transform\n", "import numpy as np\n", "\n", - "url = 'https://github.com/mapbox/rasterio/raw/master/tests/data/RGB.byte.tif'\n", - "da = xr.open_rasterio(url)\n", + "da = xr.tutorial.open_rasterio(\"RGB.byte\")\n", "\n", "# Compute the lon/lat coordinates with rasterio.warp.transform\n", "ny, nx = len(da['y']), len(da['x'])\n", diff --git a/doc/gallery.rst b/doc/gallery.rst new file mode 100644 index 00000000000..9e5284cc2ee --- /dev/null +++ b/doc/gallery.rst @@ -0,0 +1,130 @@ +Gallery +======= + +Here's a list of examples on how to use xarray. We will be adding more examples soon. +Contributions are highly welcomed and appreciated. So, if you are interested in contributing, please consult the +:doc:`contributing` guide. + + + +Notebook Examples +----------------- + +.. panels:: + :column: text-center col-lg-6 col-md-6 col-sm-12 col-xs-12 p-2 + :card: +my-2 + :img-top-cls: w-75 m-auto p-2 + :body: d-none + + --- + :img-top: _static/thumbnails/toy-weather-data.png + ++++ + .. link-button:: examples/weather-data + :type: ref + :text: Toy weather data + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/monthly-means.png + ++++ + .. link-button:: examples/monthly-means + :type: ref + :text: Calculating Seasonal Averages from Timeseries of Monthly Means + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/area_weighted_temperature.png + ++++ + .. link-button:: examples/area_weighted_temperature + :type: ref + :text: Compare weighted and unweighted mean temperature + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/multidimensional-coords.png + ++++ + .. link-button:: examples/multidimensional-coords + :type: ref + :text: Working with Multidimensional Coordinates + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/visualization_gallery.png + ++++ + .. link-button:: examples/visualization_gallery + :type: ref + :text: Visualization Gallery + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/ROMS_ocean_model.png + ++++ + .. link-button:: examples/ROMS_ocean_model + :type: ref + :text: ROMS Ocean Model Example + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/thumbnails/ERA5-GRIB-example.png + ++++ + .. link-button:: examples/ERA5-GRIB-example + :type: ref + :text: GRIB Data Example + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/dataset-diagram-square-logo.png + ++++ + .. link-button:: examples/apply_ufunc_vectorize_1d + :type: ref + :text: Applying unvectorized functions with apply_ufunc + :classes: btn-outline-dark btn-block stretched-link + + +.. toctree:: + :maxdepth: 1 + :hidden: + + examples/weather-data + examples/monthly-means + examples/area_weighted_temperature + examples/multidimensional-coords + examples/visualization_gallery + examples/ROMS_ocean_model + examples/ERA5-GRIB-example + examples/apply_ufunc_vectorize_1d + + +External Examples +----------------- + + +.. panels:: + :column: text-center col-lg-6 col-md-6 col-sm-12 col-xs-12 p-2 + :card: +my-2 + :img-top-cls: w-75 m-auto p-2 + :body: d-none + + --- + :img-top: _static/dataset-diagram-square-logo.png + ++++ + .. link-button:: https://corteva.github.io/rioxarray/stable/examples/examples.html + :type: url + :text: Managing raster data with rioxarray + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: https://avatars.githubusercontent.com/u/60833341?s=200&v=4 + ++++ + .. link-button:: http://gallery.pangeo.io/ + :type: url + :text: Xarray and dask on the cloud with Pangeo + :classes: btn-outline-dark btn-block stretched-link + + --- + :img-top: _static/dataset-diagram-square-logo.png + ++++ + .. link-button:: https://examples.dask.org/xarray.html + :type: url + :text: Xarray with Dask Arrays + :classes: btn-outline-dark btn-block stretched-link diff --git a/doc/faq.rst b/doc/getting-started-guide/faq.rst similarity index 99% rename from doc/faq.rst rename to doc/getting-started-guide/faq.rst index a2151cc4b37..4cf3cc5b63d 100644 --- a/doc/faq.rst +++ b/doc/getting-started-guide/faq.rst @@ -185,7 +185,7 @@ for more details. What other projects leverage xarray? ------------------------------------ -See section :ref:`related-projects`. +See section :ref:`ecosystem`. How should I cite xarray? ------------------------- diff --git a/doc/getting-started-guide/index.rst b/doc/getting-started-guide/index.rst new file mode 100644 index 00000000000..20fd49fb2c4 --- /dev/null +++ b/doc/getting-started-guide/index.rst @@ -0,0 +1,15 @@ +################ +Getting Started +################ + +The getting started guide aims to get you using xarray productively as quickly as possible. +It is designed as an entry point for new users, and it provided an introduction to xarray's main concepts. + +.. toctree:: + :maxdepth: 2 + :hidden: + + why-xarray + installing + quick-overview + faq diff --git a/doc/installing.rst b/doc/getting-started-guide/installing.rst similarity index 78% rename from doc/installing.rst rename to doc/getting-started-guide/installing.rst index 99b8b621aed..506236f3b9a 100644 --- a/doc/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -8,8 +8,8 @@ Required dependencies - Python (3.7 or later) - setuptools (40.4 or later) -- `numpy `__ (1.15 or later) -- `pandas `__ (0.25 or later) +- `numpy `__ (1.17 or later) +- `pandas `__ (1.0 or later) .. _optional-dependencies: @@ -77,14 +77,6 @@ Alternative data containers ~~~~~~~~~~~~~~~~~~~~~~~~~~~ - `sparse `_: for sparse arrays - `pint `_: for units of measure - - .. note:: - - At the moment of writing, xarray requires a `highly experimental version of pint - `_ (install with - ``pip install git+https://github.com/andrewgsavage/pint.git@refs/pull/6/head)``. - Even with it, interaction with non-numpy array libraries, e.g. dask or sparse, is broken. - - Any numpy-like objects that support `NEP-18 `_. Note that while such libraries theoretically should work, they are untested. @@ -98,12 +90,12 @@ Minimum dependency versions xarray adopts a rolling policy regarding the minimum supported version of its dependencies: -- **Python:** 42 months +- **Python:** 24 months (`NEP-29 `_) - **setuptools:** 42 months (but no older than 40.4) -- **numpy:** 24 months +- **numpy:** 18 months (`NEP-29 `_) -- **dask and dask.distributed:** 12 months (but no older than 2.9) +- **dask and dask.distributed:** 12 months - **sparse, pint** and other libraries that rely on `NEP-18 `_ for integration: very latest available versions only, until the technology will have @@ -111,16 +103,16 @@ dependencies: numpy >=1.17. - **all other libraries:** 12 months -The above should be interpreted as *the minor version (X.Y) initially published no more -than N months ago*. Patch versions (x.y.Z) are not pinned, and only the latest available -at the moment of publishing the xarray release is guaranteed to work. +This means the latest minor (X.Y) version from N months prior. Patch versions (x.y.Z) +are not pinned, and only the latest available at the moment of publishing the xarray +release is guaranteed to work. You can see the actual minimum tested versions: - `For NEP-18 libraries - `_ + `_ - `For everything else - `_ + `_ .. _installation-instructions: @@ -144,15 +136,15 @@ being updated in the default channel. If you don't use conda, be sure you have the required dependencies (numpy and pandas) installed first. Then, install xarray with pip:: - $ pip install xarray + $ python -m pip install xarray We also maintain other dependency sets for different subsets of functionality:: - $ pip install "xarray[io]" # Install optional dependencies for handling I/O - $ pip install "xarray[accel]" # Install optional dependencies for accelerating xarray - $ pip install "xarray[parallel]" # Install optional dependencies for dask arrays - $ pip install "xarray[viz]" # Install optional dependencies for visualization - $ pip install "xarray[complete]" # Install all the above + $ python -m pip install "xarray[io]" # Install optional dependencies for handling I/O + $ python -m pip install "xarray[accel]" # Install optional dependencies for accelerating xarray + $ python -m pip install "xarray[parallel]" # Install optional dependencies for dask arrays + $ python -m pip install "xarray[viz]" # Install optional dependencies for visualization + $ python -m pip install "xarray[complete]" # Install all the above The above commands should install most of the `optional dependencies`_. However, some packages which are either not listed on PyPI or require extra @@ -160,7 +152,7 @@ installation steps are excluded. To know which dependencies would be installed, take a look at the ``[options.extras_require]`` section in ``setup.cfg``: -.. literalinclude:: ../setup.cfg +.. literalinclude:: ../../setup.cfg :language: ini :start-at: [options.extras_require] :end-before: [options.package_data] @@ -177,8 +169,12 @@ repository. Performance Monitoring ~~~~~~~~~~~~~~~~~~~~~~ -A fixed-point performance monitoring of (a part of) our codes can be seen on -`this page `__. +.. + TODO: uncomment once we have a working setup + see https://github.com/pydata/xarray/pull/5066 + + A fixed-point performance monitoring of (a part of) our code can be seen on + `this page `__. To run these benchmark tests in a local machine, first install diff --git a/doc/quick-overview.rst b/doc/getting-started-guide/quick-overview.rst similarity index 99% rename from doc/quick-overview.rst rename to doc/getting-started-guide/quick-overview.rst index 1a2bc809550..aa822ea6373 100644 --- a/doc/quick-overview.rst +++ b/doc/getting-started-guide/quick-overview.rst @@ -176,7 +176,7 @@ objects. You can think of it as a multi-dimensional generalization of the .. ipython:: python - ds = xr.Dataset({"foo": data, "bar": ("x", [1, 2]), "baz": np.pi}) + ds = xr.Dataset(dict(foo=data, bar=("x", [1, 2]), baz=np.pi)) ds diff --git a/doc/why-xarray.rst b/doc/getting-started-guide/why-xarray.rst similarity index 100% rename from doc/why-xarray.rst rename to doc/getting-started-guide/why-xarray.rst diff --git a/doc/howdoi.rst b/doc/howdoi.rst index 3604d66bd0c..c518b0daba6 100644 --- a/doc/howdoi.rst +++ b/doc/howdoi.rst @@ -23,6 +23,8 @@ How do I ... - :py:meth:`Dataset.set_coords` * - change the order of dimensions - :py:meth:`DataArray.transpose`, :py:meth:`Dataset.transpose` + * - reshape dimensions + - :py:meth:`DataArray.stack`, :py:meth:`Dataset.stack`, :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` * - remove a variable from my object - :py:meth:`Dataset.drop_vars`, :py:meth:`DataArray.drop_vars` * - remove dimensions of length 1 or 0 @@ -34,7 +36,9 @@ How do I ... * - rename a variable, dimension or coordinate - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename`, :py:meth:`Dataset.rename_vars`, :py:meth:`Dataset.rename_dims`, * - convert a DataArray to Dataset or vice versa - - :py:meth:`DataArray.to_dataset`, :py:meth:`Dataset.to_array` + - :py:meth:`DataArray.to_dataset`, :py:meth:`Dataset.to_array`, :py:meth:`Dataset.to_stacked_array`, :py:meth:`DataArray.to_unstacked_dataset` + * - extract variables that have certain attributes + - :py:meth:`Dataset.filter_by_attrs` * - extract the underlying array (e.g. numpy or Dask arrays) - :py:attr:`DataArray.data` * - convert to and extract the underlying numpy array @@ -43,6 +47,8 @@ How do I ... - :py:func:`dask.is_dask_collection` * - know how much memory my object requires - :py:attr:`DataArray.nbytes`, :py:attr:`Dataset.nbytes` + * - Get axis number for a dimension + - :py:meth:`DataArray.get_axis_num` * - convert a possibly irregularly sampled timeseries to a regularly sampled timeseries - :py:meth:`DataArray.resample`, :py:meth:`Dataset.resample` (see :ref:`resampling` for more) * - apply a function on all data variables in a Dataset @@ -51,6 +57,8 @@ How do I ... - :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="h5netcdf", invalid_netcdf=True`` * - make xarray objects look like other xarray objects - :py:func:`~xarray.ones_like`, :py:func:`~xarray.zeros_like`, :py:func:`~xarray.full_like`, :py:meth:`Dataset.reindex_like`, :py:meth:`Dataset.interp_like`, :py:meth:`Dataset.broadcast_like`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interp_like`, :py:meth:`DataArray.broadcast_like` + * - Make sure my datasets have values at the same coordinate locations + - ``xr.align(dataset_1, dataset_2, join="exact")`` * - replace NaNs with other values - :py:meth:`Dataset.fillna`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill`, :py:meth:`Dataset.interpolate_na`, :py:meth:`DataArray.fillna`, :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`DataArray.interpolate_na` * - extract the year, month, day or similar from a DataArray of time values @@ -59,3 +67,7 @@ How do I ... - ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more. * - make a mask that is ``True`` where an object contains any of the values in a array - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin` + * - Index using a boolean mask + - :py:meth:`Dataset.query`, :py:meth:`DataArray.query`, :py:meth:`Dataset.where`, :py:meth:`DataArray.where` + * - preserve ``attrs`` during (most) xarray operations + - ``xr.set_options(keep_attrs=True)`` diff --git a/doc/index.rst b/doc/index.rst index ee44d0ad4d9..c4c9d89264b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -22,122 +22,56 @@ computing. .. _dask: http://dask.org .. _netCDF: http://www.unidata.ucar.edu/software/netcdf -Documentation -------------- - -**Getting Started** - -* :doc:`why-xarray` -* :doc:`faq` -* :doc:`quick-overview` -* :doc:`examples` -* :doc:`installing` .. toctree:: - :maxdepth: 1 + :maxdepth: 2 :hidden: - :caption: Getting Started - - why-xarray - faq - quick-overview - examples - installing - -**User Guide** - -* :doc:`terminology` -* :doc:`data-structures` -* :doc:`indexing` -* :doc:`interpolation` -* :doc:`computation` -* :doc:`groupby` -* :doc:`reshaping` -* :doc:`combining` -* :doc:`time-series` -* :doc:`weather-climate` -* :doc:`pandas` -* :doc:`io` -* :doc:`dask` -* :doc:`plotting` -* :doc:`duckarrays` + :caption: For users + + Getting Started + User Guide + Gallery + Tutorials & Videos + API Reference + How do I ... + Ecosystem .. toctree:: - :maxdepth: 1 + :maxdepth: 2 :hidden: - :caption: User Guide - - terminology - data-structures - indexing - interpolation - computation - groupby - reshaping - combining - time-series - weather-climate - pandas - io - dask - plotting - duckarrays - -**Help & reference** - -* :doc:`whats-new` -* :doc:`howdoi` -* :doc:`api` -* :doc:`internals` -* :doc:`roadmap` -* :doc:`contributing` -* :doc:`related-projects` + :caption: For developers/contributors + + Contributing Guide + Xarray Internals + Development Roadmap + Team + What’s New + GitHub repository .. toctree:: :maxdepth: 1 :hidden: - :caption: Help & reference + :caption: Community + + GitHub discussions + StackOverflow - whats-new - howdoi - api - internals - roadmap - contributing - related-projects -See also --------- -- `Xarray's Tutorial`_ presented at the 2020 SciPy Conference (`video recording`_). -- Stephan Hoyer and Joe Hamman's `Journal of Open Research Software paper`_ describing the xarray project. -- The `UW eScience Institute's Geohackweek`_ tutorial on xarray for geospatial data scientists. -- Stephan Hoyer's `SciPy2015 talk`_ introducing xarray to a general audience. -- Stephan Hoyer's `2015 Unidata Users Workshop talk`_ and `tutorial`_ (`with answers`_) introducing - xarray to users familiar with netCDF. -- `Nicolas Fauchereau's tutorial`_ on xarray for netCDF users. - -.. _Xarray's Tutorial: https://xarray-contrib.github.io/xarray-tutorial/ -.. _video recording: https://youtu.be/mecN-Ph_-78 -.. _Journal of Open Research Software paper: http://doi.org/10.5334/jors.148 -.. _UW eScience Institute's Geohackweek : https://geohackweek.github.io/nDarrays/ -.. _SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk -.. _2015 Unidata Users Workshop talk: https://www.youtube.com/watch?v=J9ypQOnt5l8 -.. _tutorial: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial.ipynb -.. _with answers: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial-with-answers.ipynb -.. _Nicolas Fauchereau's tutorial: http://nbviewer.iPython.org/github/nicolasfauchereau/metocean/blob/master/notebooks/xray.ipynb Get in touch ------------ -- Ask usage questions ("How do I?") on `StackOverflow`_. +- If you have a question like "How do I concatenate a list of datasets?", ask on `GitHub discussions`_ or `StackOverflow`_. + Please include a self-contained reproducible example if possible. - Report bugs, suggest features or view the source code `on GitHub`_. - For less well defined questions or ideas, or to announce other projects of - interest to xarray users, use the `mailing list`_. + interest to xarray users, use `GitHub discussions`_ or the `mailing list`_. -.. _StackOverFlow: http://stackoverflow.com/questions/tagged/python-xarray +.. _StackOverFlow: https://stackoverflow.com/questions/tagged/python-xarray +.. _Github discussions: https://github.com/pydata/xarray/discussions .. _mailing list: https://groups.google.com/forum/#!forum/xarray -.. _on GitHub: http://github.com/pydata/xarray +.. _on GitHub: https://github.com/pydata/xarray NumFOCUS -------- diff --git a/doc/internals.rst b/doc/internals.rst deleted file mode 100644 index 60d32128c60..00000000000 --- a/doc/internals.rst +++ /dev/null @@ -1,233 +0,0 @@ -.. _internals: - -xarray Internals -================ - -.. currentmodule:: xarray - -xarray builds upon two of the foundational libraries of the scientific Python -stack, NumPy and pandas. It is written in pure Python (no C or Cython -extensions), which makes it easy to develop and extend. Instead, we push -compiled code to :ref:`optional dependencies`. - -Variable objects ----------------- - -The core internal data structure in xarray is the :py:class:`~xarray.Variable`, -which is used as the basic building block behind xarray's -:py:class:`~xarray.Dataset` and :py:class:`~xarray.DataArray` types. A -``Variable`` consists of: - -- ``dims``: A tuple of dimension names. -- ``data``: The N-dimensional array (typically, a NumPy or Dask array) storing - the Variable's data. It must have the same number of dimensions as the length - of ``dims``. -- ``attrs``: An ordered dictionary of metadata associated with this array. By - convention, xarray's built-in operations never use this metadata. -- ``encoding``: Another ordered dictionary used to store information about how - these variable's data is represented on disk. See :ref:`io.encoding` for more - details. - -``Variable`` has an interface similar to NumPy arrays, but extended to make use -of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` -argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. - -However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not -include coordinate labels along each axis. - -``Variable`` is public API, but because of its incomplete support for labeled -data, it is mostly intended for advanced uses, such as in xarray itself or for -writing new backends. You can access the variable objects that correspond to -xarray objects via the (readonly) :py:attr:`Dataset.variables -` and -:py:attr:`DataArray.variable ` attributes. - - -.. _internals.duck_arrays: - -Integrating with duck arrays ----------------------------- - -.. warning:: - - This is a experimental feature. - -xarray can wrap custom :term:`duck array` objects as long as they define numpy's -``shape``, ``dtype`` and ``ndim`` properties and the ``__array__``, -``__array_ufunc__`` and ``__array_function__`` methods. - -In certain situations (e.g. when printing the collapsed preview of -variables of a ``Dataset``), xarray will display the repr of a :term:`duck array` -in a single line, truncating it to a certain number of characters. If that -would drop too much information, the :term:`duck array` may define a -``_repr_inline_`` method that takes ``max_width`` (number of characters) as an -argument: - -.. code:: python - - class MyDuckArray: - ... - - def _repr_inline_(self, max_width): - """ format to a single line with at most max_width characters """ - ... - - ... - - -Extending xarray ----------------- - -.. ipython:: python - :suppress: - - import numpy as np - import pandas as pd - import xarray as xr - - np.random.seed(123456) - -xarray is designed as a general purpose library, and hence tries to avoid -including overly domain specific functionality. But inevitably, the need for more -domain specific logic arises. - -One standard solution to this problem is to subclass Dataset and/or DataArray to -add domain specific functionality. However, inheritance is not very robust. It's -easy to inadvertently use internal APIs when subclassing, which means that your -code may break when xarray upgrades. Furthermore, many builtin methods will -only return native xarray objects. - -The standard advice is to use `composition over inheritance`__, but -reimplementing an API as large as xarray's on your own objects can be an onerous -task, even if most methods are only forwarding to xarray implementations. - -__ https://github.com/pydata/xarray/issues/706 - -If you simply want the ability to call a function with the syntax of a -method call, then the builtin :py:meth:`~xarray.DataArray.pipe` method (copied -from pandas) may suffice. - -To resolve this issue for more complex cases, xarray has the -:py:func:`~xarray.register_dataset_accessor` and -:py:func:`~xarray.register_dataarray_accessor` decorators for adding custom -"accessors" on xarray objects. Here's how you might use these decorators to -write a custom "geo" accessor implementing a geography specific extension to -xarray: - -.. literalinclude:: examples/_code/accessor_example.py - -In general, the only restriction on the accessor class is that the ``__init__`` method -must have a single parameter: the ``Dataset`` or ``DataArray`` object it is supposed -to work on. - -This achieves the same result as if the ``Dataset`` class had a cached property -defined that returns an instance of your class: - -.. code-block:: python - - class Dataset: - ... - - @property - def geo(self): - return GeoAccessor(self) - -However, using the register accessor decorators is preferable to simply adding -your own ad-hoc property (i.e., ``Dataset.geo = property(...)``), for several -reasons: - -1. It ensures that the name of your property does not accidentally conflict with - any other attributes or methods (including other accessors). -2. Instances of accessor object will be cached on the xarray object that creates - them. This means you can save state on them (e.g., to cache computed - properties). -3. Using an accessor provides an implicit namespace for your custom - functionality that clearly identifies it as separate from built-in xarray - methods. - -.. note:: - - Accessors are created once per DataArray and Dataset instance. New - instances, like those created from arithmetic operations or when accessing - a DataArray from a Dataset (ex. ``ds[var_name]``), will have new - accessors created. - -Back in an interactive IPython session, we can use these properties: - -.. ipython:: python - :suppress: - - exec(open("examples/_code/accessor_example.py").read()) - -.. ipython:: python - - ds = xr.Dataset({"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)}) - ds.geo.center - ds.geo.plot() - -The intent here is that libraries that extend xarray could add such an accessor -to implement subclass specific functionality rather than using actual subclasses -or patching in a large number of domain specific methods. For further reading -on ways to write new accessors and the philosophy behind the approach, see -:issue:`1080`. - -To help users keep things straight, please `let us know -`_ if you plan to write a new accessor -for an open source library. In the future, we will maintain a list of accessors -and the libraries that implement them on this page. - -To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` -easier, you can use `sphinx-autosummary-accessors`_. - -.. _sphinx-autosummary-accessors: https://sphinx-autosummary-accessors.readthedocs.io/ - -.. _zarr_encoding: - -Zarr Encoding Specification ---------------------------- - -In implementing support for the `Zarr `_ storage -format, Xarray developers made some *ad hoc* choices about how to store -NetCDF data in Zarr. -Future versions of the Zarr spec will likely include a more formal convention -for the storage of the NetCDF data model in Zarr; see -`Zarr spec repo `_ for ongoing -discussion. - -First, Xarray can only read and write Zarr groups. There is currently no support -for reading / writting individual Zarr arrays. Zarr groups are mapped to -Xarray ``Dataset`` objects. - -Second, from Xarray's point of view, the key difference between -NetCDF and Zarr is that all NetCDF arrays have *dimension names* while Zarr -arrays do not. Therefore, in order to store NetCDF data in Zarr, Xarray must -somehow encode and decode the name of each array's dimensions. - -To accomplish this, Xarray developers decided to define a special Zarr array -attribute: ``_ARRAY_DIMENSIONS``. The value of this attribute is a list of -dimension names (strings), for example ``["time", "lon", "lat"]``. When writing -data to Zarr, Xarray sets this attribute on all variables based on the variable -dimensions. When reading a Zarr group, Xarray looks for this attribute on all -arrays, raising an error if it can't be found. The attribute is used to define -the variable dimension names and then removed from the attributes dictionary -returned to the user. - -Because of these choices, Xarray cannot read arbitrary array data, but only -Zarr data with valid ``_ARRAY_DIMENSIONS`` attributes on each array. - -After decoding the ``_ARRAY_DIMENSIONS`` attribute and assigning the variable -dimensions, Xarray proceeds to [optionally] decode each variable using its -standard CF decoding machinery used for NetCDF data (see :py:func:`decode_cf`). - -As a concrete example, here we write a tutorial dataset to Zarr and then -re-open it directly with Zarr: - -.. ipython:: python - - ds = xr.tutorial.load_dataset("rasm") - ds.to_zarr("rasm.zarr", mode="w") - import zarr - - zgroup = zarr.open("rasm.zarr") - print(zgroup.tree()) - dict(zgroup["Tair"].attrs) diff --git a/doc/internals/duck-arrays-integration.rst b/doc/internals/duck-arrays-integration.rst new file mode 100644 index 00000000000..1f492c82a62 --- /dev/null +++ b/doc/internals/duck-arrays-integration.rst @@ -0,0 +1,51 @@ + +.. _internals.duck_arrays: + +Integrating with duck arrays +============================= + +.. warning:: + + This is a experimental feature. + +xarray can wrap custom :term:`duck array` objects as long as they define numpy's +``shape``, ``dtype`` and ``ndim`` properties and the ``__array__``, +``__array_ufunc__`` and ``__array_function__`` methods. + +In certain situations (e.g. when printing the collapsed preview of +variables of a ``Dataset``), xarray will display the repr of a :term:`duck array` +in a single line, truncating it to a certain number of characters. If that +would drop too much information, the :term:`duck array` may define a +``_repr_inline_`` method that takes ``max_width`` (number of characters) as an +argument: + +.. code:: python + + class MyDuckArray: + ... + + def _repr_inline_(self, max_width): + """format to a single line with at most max_width characters""" + ... + + ... + +To avoid duplicated information, this method must omit information about the shape and +:term:`dtype`. For example, the string representation of a ``dask`` array or a +``sparse`` matrix would be: + +.. ipython:: python + + import dask.array as da + import xarray as xr + import sparse + + a = da.linspace(0, 1, 20, chunks=2) + a + + b = np.eye(10) + b[[5, 7, 3, 0], [6, 8, 2, 9]] = 2 + b = sparse.COO.from_numpy(b) + b + + xr.Dataset(dict(a=("x", a), b=(("y", "z"), b))) diff --git a/doc/internals/extending-xarray.rst b/doc/internals/extending-xarray.rst new file mode 100644 index 00000000000..ef26f30689e --- /dev/null +++ b/doc/internals/extending-xarray.rst @@ -0,0 +1,103 @@ + +Extending xarray +================ + +.. ipython:: python + :suppress: + + import xarray as xr + + +xarray is designed as a general purpose library, and hence tries to avoid +including overly domain specific functionality. But inevitably, the need for more +domain specific logic arises. + +One standard solution to this problem is to subclass Dataset and/or DataArray to +add domain specific functionality. However, inheritance is not very robust. It's +easy to inadvertently use internal APIs when subclassing, which means that your +code may break when xarray upgrades. Furthermore, many builtin methods will +only return native xarray objects. + +The standard advice is to use `composition over inheritance`__, but +reimplementing an API as large as xarray's on your own objects can be an onerous +task, even if most methods are only forwarding to xarray implementations. + +__ https://github.com/pydata/xarray/issues/706 + +If you simply want the ability to call a function with the syntax of a +method call, then the builtin :py:meth:`~xarray.DataArray.pipe` method (copied +from pandas) may suffice. + +To resolve this issue for more complex cases, xarray has the +:py:func:`~xarray.register_dataset_accessor` and +:py:func:`~xarray.register_dataarray_accessor` decorators for adding custom +"accessors" on xarray objects. Here's how you might use these decorators to +write a custom "geo" accessor implementing a geography specific extension to +xarray: + +.. literalinclude:: ../examples/_code/accessor_example.py + +In general, the only restriction on the accessor class is that the ``__init__`` method +must have a single parameter: the ``Dataset`` or ``DataArray`` object it is supposed +to work on. + +This achieves the same result as if the ``Dataset`` class had a cached property +defined that returns an instance of your class: + +.. code-block:: python + + class Dataset: + ... + + @property + def geo(self): + return GeoAccessor(self) + +However, using the register accessor decorators is preferable to simply adding +your own ad-hoc property (i.e., ``Dataset.geo = property(...)``), for several +reasons: + +1. It ensures that the name of your property does not accidentally conflict with + any other attributes or methods (including other accessors). +2. Instances of accessor object will be cached on the xarray object that creates + them. This means you can save state on them (e.g., to cache computed + properties). +3. Using an accessor provides an implicit namespace for your custom + functionality that clearly identifies it as separate from built-in xarray + methods. + +.. note:: + + Accessors are created once per DataArray and Dataset instance. New + instances, like those created from arithmetic operations or when accessing + a DataArray from a Dataset (ex. ``ds[var_name]``), will have new + accessors created. + +Back in an interactive IPython session, we can use these properties: + +.. ipython:: python + :suppress: + + exec(open("examples/_code/accessor_example.py").read()) + +.. ipython:: python + + ds = xr.Dataset({"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)}) + ds.geo.center + ds.geo.plot() + +The intent here is that libraries that extend xarray could add such an accessor +to implement subclass specific functionality rather than using actual subclasses +or patching in a large number of domain specific methods. For further reading +on ways to write new accessors and the philosophy behind the approach, see +:issue:`1080`. + +To help users keep things straight, please `let us know +`_ if you plan to write a new accessor +for an open source library. In the future, we will maintain a list of accessors +and the libraries that implement them on this page. + +To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` +easier, you can use `sphinx-autosummary-accessors`_. + +.. _sphinx-autosummary-accessors: https://sphinx-autosummary-accessors.readthedocs.io/ diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst new file mode 100644 index 00000000000..251b0a17325 --- /dev/null +++ b/doc/internals/how-to-add-new-backend.rst @@ -0,0 +1,457 @@ +.. _add_a_backend: + +How to add a new backend +------------------------ + +Adding a new backend for read support to Xarray does not require +to integrate any code in Xarray; all you need to do is: + +- Create a class that inherits from Xarray :py:class:`~xarray.backends.BackendEntrypoint` + and implements the method ``open_dataset`` see :ref:`RST backend_entrypoint` + +- Declare this class as an external plugin in your ``setup.py``, see :ref:`RST backend_registration` + +If you also want to support lazy loading and dask see :ref:`RST lazy_loading`. + +Note that the new interface for backends is available from Xarray +version >= 0.18 onwards. + +.. _RST backend_entrypoint: + +BackendEntrypoint subclassing ++++++++++++++++++++++++++++++ + +Your ``BackendEntrypoint`` sub-class is the primary interface with Xarray, and +it should implement the following attributes and methods: + +- the ``open_dataset`` method (mandatory) +- the ``open_dataset_parameters`` attribute (optional) +- the ``guess_can_open`` method (optional). + +This is what a ``BackendEntrypoint`` subclass should look like: + +.. code-block:: python + + from xarray.backends import BackendEntrypoint + + + class MyBackendEntrypoint(BackendEntrypoint): + def open_dataset( + filename_or_obj, + *, + drop_variables=None, + # other backend specific keyword arguments + # `chunks` and `cache` DO NOT go here, they are handled by xarray + ): + return my_open_dataset(filename_or_obj, drop_variables=drop_variables) + + open_dataset_parameters = ["filename_or_obj", "drop_variables"] + + def guess_can_open(filename_or_obj): + try: + _, ext = os.path.splitext(filename_or_obj) + except TypeError: + return False + return ext in {".my_format", ".my_fmt"} + +``BackendEntrypoint`` subclass methods and attributes are detailed in the following. + +.. _RST open_dataset: + +open_dataset +^^^^^^^^^^^^ + +The backend ``open_dataset`` shall implement reading from file, the variables +decoding and it shall instantiate the output Xarray class :py:class:`~xarray.Dataset`. + +The following is an example of the high level processing steps: + +.. code-block:: python + + def open_dataset( + filename_or_obj, + *, + drop_variables=None, + decode_times=True, + decode_timedelta=True, + decode_coords=True, + my_backend_option=None, + ): + vars, attrs, coords = my_reader( + filename_or_obj, + drop_variables=drop_variables, + my_backend_option=my_backend_option, + ) + vars, attrs, coords = my_decode_variables( + vars, attrs, decode_times, decode_timedelta, decode_coords + ) # see also conventions.decode_cf_variables + + ds = xr.Dataset(vars, attrs=attrs, coords=coords) + ds.set_close(my_close_method) + + return ds + + +The output :py:class:`~xarray.Dataset` shall implement the additional custom method +``close``, used by Xarray to ensure the related files are eventually closed. This +method shall be set by using :py:meth:`~xarray.Dataset.set_close`. + + +The input of ``open_dataset`` method are one argument +(``filename_or_obj``) and one keyword argument (``drop_variables``): + +- ``filename_or_obj``: can be any object but usually it is a string containing a path or an instance of + :py:class:`pathlib.Path`. +- ``drop_variables``: can be `None` or an iterable containing the variable + names to be dropped when reading the data. + +If it makes sense for your backend, your ``open_dataset`` method +should implement in its interface the following boolean keyword arguments, called +**decoders**, which default to ``None``: + +- ``mask_and_scale`` +- ``decode_times`` +- ``decode_timedelta`` +- ``use_cftime`` +- ``concat_characters`` +- ``decode_coords`` + +Note: all the supported decoders shall be declared explicitly +in backend ``open_dataset`` signature and adding a ``**kargs`` is not allowed. + +These keyword arguments are explicitly defined in Xarray +:py:func:`~xarray.open_dataset` signature. Xarray will pass them to the +backend only if the User explicitly sets a value different from ``None``. +For more details on decoders see :ref:`RST decoders`. + +Your backend can also take as input a set of backend-specific keyword +arguments. All these keyword arguments can be passed to +:py:func:`~xarray.open_dataset` grouped either via the ``backend_kwargs`` +parameter or explicitly using the syntax ``**kwargs``. + + +If you don't want to support the lazy loading, then the +:py:class:`~xarray.Dataset` shall contain values as a :py:class:`numpy.ndarray` +and your work is almost done. + +.. _RST open_dataset_parameters: + +open_dataset_parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +``open_dataset_parameters`` is the list of backend ``open_dataset`` parameters. +It is not a mandatory parameter, and if the backend does not provide it +explicitly, Xarray creates a list of them automatically by inspecting the +backend signature. + +If ``open_dataset_parameters`` is not defined, but ``**kwargs`` and ``*args`` +are in the backend ``open_dataset`` signature, Xarray raises an error. +On the other hand, if the backend provides the ``open_dataset_parameters``, +then ``**kwargs`` and ``*args`` can be used in the signature. +However, this practice is discouraged unless there is a good reasons for using +``**kwargs`` or ``*args``. + +.. _RST guess_can_open: + +guess_can_open +^^^^^^^^^^^^^^ + +``guess_can_open`` is used to identify the proper engine to open your data +file automatically in case the engine is not specified explicitly. If you are +not interested in supporting this feature, you can skip this step since +:py:class:`~xarray.backends.BackendEntrypoint` already provides a +default :py:meth:`~xarray.backends.BackendEntrypoint.guess_can_open` +that always returns ``False``. + +Backend ``guess_can_open`` takes as input the ``filename_or_obj`` parameter of +Xarray :py:meth:`~xarray.open_dataset`, and returns a boolean. + +.. _RST decoders: + +Decoders +^^^^^^^^ +The decoders implement specific operations to transform data from on-disk +representation to Xarray representation. + +A classic example is the “time” variable decoding operation. In NetCDF, the +elements of the “time” variable are stored as integers, and the unit contains +an origin (for example: "seconds since 1970-1-1"). In this case, Xarray +transforms the pair integer-unit in a :py:class:`numpy.datetime64`. + +The standard coders implemented in Xarray are: + +- :py:class:`xarray.coding.strings.CharacterArrayCoder()` +- :py:class:`xarray.coding.strings.EncodedStringCoder()` +- :py:class:`xarray.coding.variables.UnsignedIntegerCoder()` +- :py:class:`xarray.coding.variables.CFMaskCoder()` +- :py:class:`xarray.coding.variables.CFScaleOffsetCoder()` +- :py:class:`xarray.coding.times.CFTimedeltaCoder()` +- :py:class:`xarray.coding.times.CFDatetimeCoder()` + +Xarray coders all have the same interface. They have two methods: ``decode`` +and ``encode``. The method ``decode`` takes a ``Variable`` in on-disk +format and returns a ``Variable`` in Xarray format. Variable +attributes no more applicable after the decoding, are dropped and stored in the +``Variable.encoding`` to make them available to the ``encode`` method, which +performs the inverse transformation. + +In the following an example on how to use the coders ``decode`` method: + +.. ipython:: python + + var = xr.Variable( + dims=("x",), data=np.arange(10.0), attrs={"scale_factor": 10, "add_offset": 2} + ) + var + + coder = xr.coding.variables.CFScaleOffsetCoder() + decoded_var = coder.decode(var) + decoded_var + decoded_var.encoding + +Some of the transformations can be common to more backends, so before +implementing a new decoder, be sure Xarray does not already implement that one. + +The backends can reuse Xarray’s decoders, either instantiating the coders +and using the method ``decode`` directly or using the higher-level function +:py:func:`~xarray.conventions.decode_cf_variables` that groups Xarray decoders. + +In some cases, the transformation to apply strongly depends on the on-disk +data format. Therefore, you may need to implement your own decoder. + +An example of such a case is when you have to deal with the time format of a +grib file. grib format is very different from the NetCDF one: in grib, the +time is stored in two attributes dataDate and dataTime as strings. Therefore, +it is not possible to reuse the Xarray time decoder, and implementing a new +one is mandatory. + +Decoders can be activated or deactivated using the boolean keywords of +Xarray :py:meth:`~xarray.open_dataset` signature: ``mask_and_scale``, +``decode_times``, ``decode_timedelta``, ``use_cftime``, +``concat_characters``, ``decode_coords``. +Such keywords are passed to the backend only if the User sets a value +different from ``None``. Note that the backend does not necessarily have to +implement all the decoders, but it shall declare in its ``open_dataset`` +interface only the boolean keywords related to the supported decoders. + +.. _RST backend_registration: + +How to register a backend ++++++++++++++++++++++++++++ + +Define a new entrypoint in your ``setup.py`` (or ``setup.cfg``) with: + +- group: ``xarray.backends`` +- name: the name to be passed to :py:meth:`~xarray.open_dataset` as ``engine`` +- object reference: the reference of the class that you have implemented. + +You can declare the entrypoint in ``setup.py`` using the following syntax: + +.. code-block:: + + setuptools.setup( + entry_points={ + "xarray.backends": ["my_engine=my_package.my_module:MyBackendEntryClass"], + }, + ) + +in ``setup.cfg``: + +.. code-block:: cfg + + [options.entry_points] + xarray.backends = + my_engine = my_package.my_module:MyBackendEntryClass + + +See https://packaging.python.org/specifications/entry-points/#data-model +for more information + +If you are using `Poetry `_ for your build system, you can accomplish the same thing using "plugins". In this case you would need to add the following to your ``pyproject.toml`` file: + +.. code-block:: toml + + [tool.poetry.plugins."xarray_backends"] + "my_engine" = "my_package.my_module:MyBackendEntryClass" + +See https://python-poetry.org/docs/pyproject/#plugins for more information on Poetry plugins. + +.. _RST lazy_loading: + +How to support Lazy Loading ++++++++++++++++++++++++++++ +If you want to make your backend effective with big datasets, then you should +support lazy loading. +Basically, you shall replace the :py:class:`numpy.ndarray` inside the +variables with a custom class that supports lazy loading indexing. +See the example below: + +.. code-block:: python + + backend_array = MyBackendArray() + data = indexing.LazilyIndexedArray(backend_array) + var = xr.Variable(dims, data, attrs=attrs, encoding=encoding) + +Where: + +- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class + provided by Xarray that manages the lazy loading. +- ``MyBackendArray`` shall be implemented by the backend and shall inherit + from :py:class:`~xarray.backends.BackendArray`. + +BackendArray subclassing +^^^^^^^^^^^^^^^^^^^^^^^^ + +The BackendArray subclass shall implement the following method and attributes: + +- the ``__getitem__`` method that takes in input an index and returns a + `NumPy `__ array +- the ``shape`` attribute +- the ``dtype`` attribute. + + +Xarray supports different type of +`indexing `__, that can be +grouped in three types of indexes +:py:class:`~xarray.core.indexing.BasicIndexer`, +:py:class:`~xarray.core.indexing.OuterIndexer` and +:py:class:`~xarray.core.indexing.VectorizedIndexer`. +This implies that the implementation of the method ``__getitem__`` can be tricky. +In oder to simplify this task, Xarray provides a helper function, +:py:func:`~xarray.core.indexing.explicit_indexing_adapter`, that transforms +all the input ``indexer`` types (`basic`, `outer`, `vectorized`) in a tuple +which is interpreted correctly by your backend. + +This is an example ``BackendArray`` subclass implementation: + +.. code-block:: python + + from xarray.backends import BackendArray + + + class MyBackendArray(BackendArray): + def __init__( + self, + shape, + dtype, + lock, + # other backend specific keyword arguments + ): + self.shape = shape + self.dtype = lock + self.lock = dtype + + def __getitem__( + self, key: xarray.core.indexing.ExplicitIndexer + ) -> np.typing.ArrayLike: + return indexing.explicit_indexing_adapter( + key, + self.shape, + indexing.IndexingSupport.BASIC, + self._raw_indexing_method, + ) + + def _raw_indexing_method(self, key: tuple) -> np.typing.ArrayLike: + # thread safe method that access to data on disk + with self.lock: + ... + return item + +Note that ``BackendArray.__getitem__`` must be thread safe to support +multi-thread processing. + +The :py:func:`~xarray.core.indexing.explicit_indexing_adapter` method takes in +input the ``key``, the array ``shape`` and the following parameters: + +- ``indexing_support``: the type of index supported by ``raw_indexing_method`` +- ``raw_indexing_method``: a method that shall take in input a key in the form + of a tuple and return an indexed :py:class:`numpy.ndarray`. + +For more details see +:py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`. + +In order to support `Dask `__ distributed and +:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable +either with :ref:`io.pickle` or +`cloudpickle `__. +That implies that all the reference to open files should be dropped. For +opening files, we therefore suggest to use the helper class provided by Xarray +:py:class:`~xarray.backends.CachingFileManager`. + +.. _RST indexing: + +Indexing Examples +^^^^^^^^^^^^^^^^^ +**BASIC** + +In the ``BASIC`` indexing support, numbers and slices are supported. + +Example: + +.. ipython:: + :verbatim: + + In [1]: # () shall return the full array + ...: backend_array._raw_indexing_method(()) + Out[1]: array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + + In [2]: # shall support integers + ...: backend_array._raw_indexing_method(1, 1) + Out[2]: 5 + + In [3]: # shall support slices + ...: backend_array._raw_indexing_method(slice(0, 3), slice(2, 4)) + Out[3]: array([[2, 3], [6, 7], [10, 11]]) + +**OUTER** + +The ``OUTER`` indexing shall support number, slices and in addition it shall +support also lists of integers. The the outer indexing is equivalent to +combining multiple input list with ``itertools.product()``: + +.. ipython:: + :verbatim: + + In [1]: backend_array._raw_indexing_method([0, 1], [0, 1, 2]) + Out[1]: array([[0, 1, 2], [4, 5, 6]]) + + # shall support integers + In [2]: backend_array._raw_indexing_method(1, 1) + Out[2]: 5 + + +**OUTER_1VECTOR** + +The ``OUTER_1VECTOR`` indexing shall supports number, slices and at most one +list. The behaviour with the list shall be the same of ``OUTER`` indexing. + +If you support more complex indexing as `explicit indexing` or +`numpy indexing`, you can have a look to the implemetation of Zarr backend and Scipy backend, +currently available in :py:mod:`~xarray.backends` module. + +.. _RST preferred_chunks: + +Backend preferred chunks +^^^^^^^^^^^^^^^^^^^^^^^^ + +The backend is not directly involved in `Dask `__ +chunking, since it is internally managed by Xarray. However, the backend can +define the preferred chunk size inside the variable’s encoding +``var.encoding["preferred_chunks"]``. The ``preferred_chunks`` may be useful +to improve performances with lazy loading. ``preferred_chunks`` shall be a +dictionary specifying chunk size per dimension like +``{“dim1”: 1000, “dim2”: 2000}`` or +``{“dim1”: [1000, 100], “dim2”: [2000, 2000, 2000]]}``. + +The ``preferred_chunks`` is used by Xarray to define the chunk size in some +special cases: + +- if ``chunks`` along a dimension is ``None`` or not defined +- if ``chunks`` is ``"auto"``. + +In the first case Xarray uses the chunks size specified in +``preferred_chunks``. +In the second case Xarray accommodates ideal chunk sizes, preserving if +possible the "preferred_chunks". The ideal chunk size is computed using +:py:func:`dask.array.core.normalize_chunks`, setting +``previous_chunks = preferred_chunks``. diff --git a/doc/internals/index.rst b/doc/internals/index.rst new file mode 100644 index 00000000000..8cedc67327e --- /dev/null +++ b/doc/internals/index.rst @@ -0,0 +1,20 @@ +.. _internals: + +xarray Internals +================ + +xarray builds upon two of the foundational libraries of the scientific Python +stack, NumPy and pandas. It is written in pure Python (no C or Cython +extensions), which makes it easy to develop and extend. Instead, we push +compiled code to :ref:`optional dependencies`. + + +.. toctree:: + :maxdepth: 2 + :hidden: + + variable-objects + duck-arrays-integration + extending-xarray + zarr-encoding-spec + how-to-add-new-backend diff --git a/doc/internals/variable-objects.rst b/doc/internals/variable-objects.rst new file mode 100644 index 00000000000..6ae3c2f7e6d --- /dev/null +++ b/doc/internals/variable-objects.rst @@ -0,0 +1,31 @@ +Variable objects +================ + +The core internal data structure in xarray is the :py:class:`~xarray.Variable`, +which is used as the basic building block behind xarray's +:py:class:`~xarray.Dataset` and :py:class:`~xarray.DataArray` types. A +``Variable`` consists of: + +- ``dims``: A tuple of dimension names. +- ``data``: The N-dimensional array (typically, a NumPy or Dask array) storing + the Variable's data. It must have the same number of dimensions as the length + of ``dims``. +- ``attrs``: An ordered dictionary of metadata associated with this array. By + convention, xarray's built-in operations never use this metadata. +- ``encoding``: Another ordered dictionary used to store information about how + these variable's data is represented on disk. See :ref:`io.encoding` for more + details. + +``Variable`` has an interface similar to NumPy arrays, but extended to make use +of named dimensions. For example, it uses ``dim`` in preference to an ``axis`` +argument for methods like ``mean``, and supports :ref:`compute.broadcasting`. + +However, unlike ``Dataset`` and ``DataArray``, the basic ``Variable`` does not +include coordinate labels along each axis. + +``Variable`` is public API, but because of its incomplete support for labeled +data, it is mostly intended for advanced uses, such as in xarray itself or for +writing new backends. You can access the variable objects that correspond to +xarray objects via the (readonly) :py:attr:`Dataset.variables +` and +:py:attr:`DataArray.variable ` attributes. diff --git a/doc/internals/zarr-encoding-spec.rst b/doc/internals/zarr-encoding-spec.rst new file mode 100644 index 00000000000..082d7984f59 --- /dev/null +++ b/doc/internals/zarr-encoding-spec.rst @@ -0,0 +1,65 @@ +.. currentmodule:: xarray + +.. _zarr_encoding: + +Zarr Encoding Specification +============================ + +In implementing support for the `Zarr `_ storage +format, Xarray developers made some *ad hoc* choices about how to store +NetCDF data in Zarr. +Future versions of the Zarr spec will likely include a more formal convention +for the storage of the NetCDF data model in Zarr; see +`Zarr spec repo `_ for ongoing +discussion. + +First, Xarray can only read and write Zarr groups. There is currently no support +for reading / writting individual Zarr arrays. Zarr groups are mapped to +Xarray ``Dataset`` objects. + +Second, from Xarray's point of view, the key difference between +NetCDF and Zarr is that all NetCDF arrays have *dimension names* while Zarr +arrays do not. Therefore, in order to store NetCDF data in Zarr, Xarray must +somehow encode and decode the name of each array's dimensions. + +To accomplish this, Xarray developers decided to define a special Zarr array +attribute: ``_ARRAY_DIMENSIONS``. The value of this attribute is a list of +dimension names (strings), for example ``["time", "lon", "lat"]``. When writing +data to Zarr, Xarray sets this attribute on all variables based on the variable +dimensions. When reading a Zarr group, Xarray looks for this attribute on all +arrays, raising an error if it can't be found. The attribute is used to define +the variable dimension names and then removed from the attributes dictionary +returned to the user. + +Because of these choices, Xarray cannot read arbitrary array data, but only +Zarr data with valid ``_ARRAY_DIMENSIONS`` attributes on each array. + +After decoding the ``_ARRAY_DIMENSIONS`` attribute and assigning the variable +dimensions, Xarray proceeds to [optionally] decode each variable using its +standard CF decoding machinery used for NetCDF data (see :py:func:`decode_cf`). + +Finally, it's worth noting that Xarray writes (and attempts to read) +"consolidated metadata" by default (the ``.zmetadata`` file), which is another +non-standard Zarr extension, albeit one implemented upstream in Zarr-Python. +You do not need to write consolidated metadata to make Zarr stores readable in +Xarray, but because Xarray can open these stores much faster, users will see a +warning about poor performance when reading non-consolidated stores unless they +explicitly set ``consolidated=False``. See :ref:`io.zarr.consolidated_metadata` +for more details. + +As a concrete example, here we write a tutorial dataset to Zarr and then +re-open it directly with Zarr: + +.. ipython:: python + + import os + import xarray as xr + import zarr + + ds = xr.tutorial.load_dataset("rasm") + ds.to_zarr("rasm.zarr", mode="w") + + zgroup = zarr.open("rasm.zarr") + print(os.listdir("rasm.zarr")) + print(zgroup.tree()) + dict(zgroup["Tair"].attrs) diff --git a/doc/roadmap.rst b/doc/roadmap.rst index 1cbbaf8ef42..dd5235bfb16 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -206,27 +206,10 @@ In order to lower this adoption barrier, we propose to: - Write a basic glossary that defines terms that might not be familiar to all (e.g. "lazy", "labeled", "serialization", "indexing", "backend"). + Administrative -------------- -Current core developers -~~~~~~~~~~~~~~~~~~~~~~~ - -- Stephan Hoyer -- Ryan Abernathey -- Joe Hamman -- Benoit Bovy -- Fabien Maussion -- Keisuke Fujii -- Maximilian Roos -- Deepak Cherian -- Spencer Clark -- Tom Nicholas -- Guido Imperiale -- Justus Magin -- Mathias Hauser -- Anderson Banihirwe - NumFOCUS ~~~~~~~~ diff --git a/doc/team.rst b/doc/team.rst new file mode 100644 index 00000000000..7b185dc3a52 --- /dev/null +++ b/doc/team.rst @@ -0,0 +1,87 @@ +Team +----- + +Current core developers +~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray core developers are responsible for the ongoing organizational maintenance and technical direction of the xarray project. + +The current core developers team comprises: + +.. panels:: + :column: col-lg-4 col-md-4 col-sm-6 col-xs-12 p-2 + :card: text-center + + --- + .. image:: https://avatars.githubusercontent.com/u/1217238?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/shoyer,"Stephan Hoyer",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/1197350?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/rabernat,"Ryan Abernathey",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/2443309?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/jhamman,"Joe Hamman",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/4160723?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/benbovy,"Benoit Bovy",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/10050469?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/fmaussion,"Fabien Maussion",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/6815844?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/fujiisoup,"Keisuke Fujii",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/5635139?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/max-sixty,"Maximilian Roos",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/2448579?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/dcherian,"Deepak Cherian",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/6628425?v=4 + +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/spencerkclark,"Spencer Clark",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/35968931?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/TomNicholas,"Tom Nicholas",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/6213168?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/crusaderky,"Guido Imperiale",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/14808389?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/keewis,"Justus Magin",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/10194086?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/mathause,"Mathias Hauser",cls=btn badge-light` + + --- + .. image:: https://avatars.githubusercontent.com/u/13301940?v=4 + ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + :link-badge:`https://github.com/andersy005,"Anderson Banihirwe",cls=btn badge-light` + + + +The full list of contributors is on our `GitHub Contributors Page `__. diff --git a/doc/tutorials-and-videos.rst b/doc/tutorials-and-videos.rst new file mode 100644 index 00000000000..22d501a4925 --- /dev/null +++ b/doc/tutorials-and-videos.rst @@ -0,0 +1,69 @@ + +Tutorials and Videos +==================== + + +Tutorials +---------- + +- `Xarray's Tutorials`_ repository +- The `UW eScience Institute's Geohackweek`_ tutorial on xarray for geospatial data scientists. +- `Nicolas Fauchereau's 2015 tutorial`_ on xarray for netCDF users. + + + +Videos +------- + +.. panels:: + :card: text-center + + --- + Xdev Python Tutorial Seminar Series 2021 seminar introducing Xarray (1 of 2) | Anderson Banihirwe + ^^^ + .. raw:: html + + + + --- + Xarray's virtual tutorial | October 2020 | Anderson Banihirwe, Deepak Cherian, and Martin Durant + ^^^ + .. raw:: html + + + + --- + Xarray's Tutorial presented at the 2020 SciPy Conference | Joe Hamman, Ryan Abernathey, + Deepak Cherian, and Stephan Hoyer + ^^^ + .. raw:: html + + + + --- + Scipy 2015 talk introducing xarray to a general audience | Stephan Hoyer + ^^^ + .. raw:: html + + + + --- + 2015 Unidata Users Workshop talk and tutorial with (`with answers`_) introducing + xarray to users familiar with netCDF | Stephan Hoyer + ^^^ + .. raw:: html + + + +Books, Chapters and Articles +----------------------------- + +- Stephan Hoyer and Joe Hamman's `Journal of Open Research Software paper`_ describing the xarray project. + + +.. _Xarray's Tutorials: https://xarray-contrib.github.io/xarray-tutorial/ +.. _Journal of Open Research Software paper: http://doi.org/10.5334/jors.148 +.. _UW eScience Institute's Geohackweek : https://geohackweek.github.io/nDarrays/ +.. _tutorial: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial.ipynb +.. _with answers: https://github.com/Unidata/unidata-users-workshop/blob/master/notebooks/xray-tutorial-with-answers.ipynb +.. _Nicolas Fauchereau's 2015 tutorial: http://nbviewer.iPython.org/github/nicolasfauchereau/metocean/blob/master/notebooks/xray.ipynb diff --git a/doc/combining.rst b/doc/user-guide/combining.rst similarity index 100% rename from doc/combining.rst rename to doc/user-guide/combining.rst diff --git a/doc/computation.rst b/doc/user-guide/computation.rst similarity index 89% rename from doc/computation.rst rename to doc/user-guide/computation.rst index dcfe270a942..c2f94b1f8e5 100644 --- a/doc/computation.rst +++ b/doc/user-guide/computation.rst @@ -6,6 +6,7 @@ Computation ########### + The labels associated with :py:class:`~xarray.DataArray` and :py:class:`~xarray.Dataset` objects enables some powerful shortcuts for computation, notably including aggregation and broadcasting by dimension @@ -443,6 +444,89 @@ The inverse operation is done with :py:meth:`~xarray.polyval`, .. note:: These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`. + +.. _compute.curvefit: + +Fitting arbitrary functions +=========================== + +Xarray objects also provide an interface for fitting more complex functions using +:py:func:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts +user-defined functions and can fit along multiple coordinates. + +For example, we can fit a relationship between two ``DataArray`` objects, maintaining +a unique fit at each spatial coordinate but aggregating over the time dimension: + +.. ipython:: python + + def exponential(x, a, xc): + return np.exp((x - xc) / a) + + + x = np.arange(-5, 5, 0.1) + t = np.arange(-5, 5, 0.1) + X, T = np.meshgrid(x, t) + Z1 = np.random.uniform(low=-5, high=5, size=X.shape) + Z2 = exponential(Z1, 3, X) + Z3 = exponential(Z1, 1, -X) + + ds = xr.Dataset( + data_vars=dict( + var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3) + ), + coords={"t": t, "x": x}, + ) + ds[["var2", "var3"]].curvefit( + coords=ds.var1, + func=exponential, + reduce_dims="t", + bounds={"a": (0.5, 5), "xc": (-5, 5)}, + ) + +We can also fit multi-dimensional functions, and even use a wrapper function to +simultaneously fit a summation of several functions, such as this field containing +two gaussian peaks: + +.. ipython:: python + + def gaussian_2d(coords, a, xc, yc, xalpha, yalpha): + x, y = coords + z = a * np.exp( + -np.square(x - xc) / 2 / np.square(xalpha) + - np.square(y - yc) / 2 / np.square(yalpha) + ) + return z + + + def multi_peak(coords, *args): + z = np.zeros(coords[0].shape) + for i in range(len(args) // 5): + z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5]) + return z + + + x = np.arange(-5, 5, 0.1) + y = np.arange(-5, 5, 0.1) + X, Y = np.meshgrid(x, y) + + n_peaks = 2 + names = ["a", "xc", "yc", "xalpha", "yalpha"] + names = [f"{name}{i}" for i in range(n_peaks) for name in names] + Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1) + Z += np.random.normal(scale=0.1, size=Z.shape) + + da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x}) + da.curvefit( + coords=["x", "y"], + func=multi_peak, + param_names=names, + kwargs={"maxfev": 10000}, + ) + +.. note:: + This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`. + + .. _compute.broadcasting: Broadcasting by dimension name diff --git a/doc/dask.rst b/doc/user-guide/dask.rst similarity index 98% rename from doc/dask.rst rename to doc/user-guide/dask.rst index 4844967350b..321b7712b9f 100644 --- a/doc/dask.rst +++ b/doc/user-guide/dask.rst @@ -21,7 +21,7 @@ and at the `Dask examples website `_. What is a Dask array? --------------------- -.. image:: _static/dask_array.png +.. image:: ../_static/dask_array.png :width: 40 % :align: right :alt: A Dask array @@ -511,6 +511,13 @@ Notice that the 0-shaped sizes were not printed to screen. Since ``template`` ha mapped.identical(expected) +.. tip:: + + As :py:func:`map_blocks` loads each block into memory, reduce as much as possible objects consumed by user functions. + For example, drop useless variables before calling ``func`` with :py:func:`map_blocks`. + + + Chunking and performance ------------------------ diff --git a/doc/data-structures.rst b/doc/user-guide/data-structures.rst similarity index 99% rename from doc/data-structures.rst rename to doc/user-guide/data-structures.rst index ac78e1769d5..f59f561fb09 100644 --- a/doc/data-structures.rst +++ b/doc/user-guide/data-structures.rst @@ -239,7 +239,7 @@ to access any variable in a dataset, datasets have four key properties: used in ``data_vars`` (e.g., arrays of numbers, datetime objects or strings) - ``attrs``: :py:class:`dict` to hold arbitrary metadata -The distinction between whether a variables falls in data or coordinates +The distinction between whether a variable falls in data or coordinates (borrowed from `CF conventions`_) is mostly semantic, and you can probably get away with ignoring it if you like: dictionary like access on a dataset will supply variables found in either category. However, xarray does make use of the @@ -251,7 +251,7 @@ quantities that belong in data. Here is an example of how we might structure a dataset for a weather forecast: -.. image:: _static/dataset-diagram.png +.. image:: ../_static/dataset-diagram.png In this example, it would be natural to call ``temperature`` and ``precipitation`` "data variables" and all the other arrays "coordinate @@ -310,12 +310,12 @@ in the dictionary: .. ipython:: python - xr.Dataset({"bar": foo}) + xr.Dataset(dict(bar=foo)) .. ipython:: python - xr.Dataset({"bar": foo.to_pandas()}) + xr.Dataset(dict(bar=foo.to_pandas())) Where a pandas object is supplied as a value, the names of its indexes are used as dimension names, and its data is aligned to any existing dimensions. diff --git a/doc/duckarrays.rst b/doc/user-guide/duckarrays.rst similarity index 98% rename from doc/duckarrays.rst rename to doc/user-guide/duckarrays.rst index ba13d5160ae..97da968b84a 100644 --- a/doc/duckarrays.rst +++ b/doc/user-guide/duckarrays.rst @@ -42,10 +42,9 @@ the code will still cast to ``numpy`` arrays: :py:meth:`DataArray.interp` and :py:meth:`DataArray.interp_like` (uses ``scipy``): duck arrays in data variables and non-dimension coordinates will be casted in addition to not supporting duck arrays in dimension coordinates + * :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (requires ``numpy>=1.20``) * :py:meth:`Dataset.rolling_exp` and :py:meth:`DataArray.rolling_exp` (uses ``numbagg``) - * :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (uses internal functions - of ``numpy``) * :py:meth:`Dataset.interpolate_na` and :py:meth:`DataArray.interpolate_na` (uses :py:class:`numpy.vectorize`) * :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`) diff --git a/doc/groupby.rst b/doc/user-guide/groupby.rst similarity index 100% rename from doc/groupby.rst rename to doc/user-guide/groupby.rst diff --git a/doc/user-guide/index.rst b/doc/user-guide/index.rst new file mode 100644 index 00000000000..edeb0aac632 --- /dev/null +++ b/doc/user-guide/index.rst @@ -0,0 +1,27 @@ +########### +User Guide +########### + +In this user guide, you will find detailed descriptions and +examples that describe many common tasks that you can accomplish with xarray. + + +.. toctree:: + :maxdepth: 2 + :hidden: + + terminology + data-structures + indexing + interpolation + computation + groupby + reshaping + combining + time-series + weather-climate + pandas + io + dask + plotting + duckarrays diff --git a/doc/indexing.rst b/doc/user-guide/indexing.rst similarity index 93% rename from doc/indexing.rst rename to doc/user-guide/indexing.rst index 78766b8fd81..263bd4f431f 100644 --- a/doc/indexing.rst +++ b/doc/user-guide/indexing.rst @@ -234,9 +234,6 @@ arrays). However, you can do normal indexing with dimension names: ds[dict(space=[0], time=[0])] ds.loc[dict(time="2000-01-01")] -Using indexing to *assign* values to a subset of dataset (e.g., -``ds[dict(space=0)] = 1``) is not yet supported. - Dropping labels and dimensions ------------------------------ @@ -395,6 +392,22 @@ These methods may also be applied to ``Dataset`` objects ds = da.to_dataset(name="bar") ds.isel(x=xr.DataArray([0, 1, 2], dims=["points"])) +Vectorized indexing may be used to extract information from the nearest +grid cells of interest, for example, the nearest climate model grid cells +to a collection specified weather station latitudes and longitudes. + +.. ipython:: python + + ds = xr.tutorial.open_dataset("air_temperature") + + # Define target latitude and longitude (where weather stations might be) + target_lon = xr.DataArray([200, 201, 202, 205], dims="points") + target_lat = xr.DataArray([31, 41, 42, 42], dims="points") + + # Retrieve data at the grid cells nearest to the target latitudes and longitudes + da = ds["air"].sel(lon=target_lon, lat=target_lat, method="nearest") + da + .. tip:: If you are lazily loading your data from disk, not every form of vectorized @@ -520,6 +533,34 @@ __ https://docs.scipy.org/doc/numpy/user/basics.indexing.html#assigning-values-t da.isel(x=[0, 1, 2])[1] = -1 da +You can also assign values to all variables of a :py:class:`Dataset` at once: + +.. ipython:: python + + ds_org = xr.tutorial.open_dataset("eraint_uvz").isel( + latitude=slice(56, 59), longitude=slice(255, 258), level=0 + ) + # set all values to 0 + ds = xr.zeros_like(ds_org) + ds + + # by integer + ds[dict(latitude=2, longitude=2)] = 1 + ds["u"] + ds["v"] + + # by label + ds.loc[dict(latitude=47.25, longitude=[11.25, 12])] = 100 + ds["u"] + + # dataset as new values + new_dat = ds_org.loc[dict(latitude=48, longitude=[11.25, 12])] + new_dat + ds.loc[dict(latitude=47.25, longitude=[11.25, 12])] = new_dat + ds["u"] + +The dimensions can differ between the variables in the dataset, but all variables need to have at least the dimensions specified in the indexer dictionary. +The new values must be either a scalar, a :py:class:`DataArray` or a :py:class:`Dataset` itself that contains all variables that also appear in the dataset to be modified. .. _more_advanced_indexing: diff --git a/doc/interpolation.rst b/doc/user-guide/interpolation.rst similarity index 99% rename from doc/interpolation.rst rename to doc/user-guide/interpolation.rst index 9a3b7a7ee2d..8a7f9ebe911 100644 --- a/doc/interpolation.rst +++ b/doc/user-guide/interpolation.rst @@ -179,7 +179,7 @@ For example, if you want to interpolate a two dimensional array along a particul you can pass two 1-dimensional :py:class:`~xarray.DataArray` s with a common dimension as new coordinate. -.. image:: _static/advanced_selection_interpolation.svg +.. image:: ../_static/advanced_selection_interpolation.svg :height: 200px :width: 400 px :alt: advanced indexing and interpolation diff --git a/doc/io.rst b/doc/user-guide/io.rst similarity index 94% rename from doc/io.rst rename to doc/user-guide/io.rst index 2e46879929b..5ec3fa4a6b9 100644 --- a/doc/io.rst +++ b/doc/user-guide/io.rst @@ -246,7 +246,7 @@ See its docstring for more details. across the datasets (ignoring floating point differences). The following command with suitable modifications (such as ``parallel=True``) works well with such datasets:: - xr.open_mfdataset('my/files/*.nc', concat_dim="time", + xr.open_mfdataset('my/files/*.nc', concat_dim="time", combine="nested", data_vars='minimal', coords='minimal', compat='override') This command concatenates variables along the ``"time"`` dimension, but only those that @@ -621,7 +621,7 @@ over the network until we look at particular values: # the data is downloaded automatically when we make the plot In [6]: tmax[0].plot() -.. image:: _static/opendap-prism-tmax.png +.. image:: ../_static/opendap-prism-tmax.png Some servers require authentication before we can access the data. For this purpose we can explicitly create a :py:class:`backends.PydapDataStore` @@ -837,11 +837,6 @@ Xarray's Zarr backend allows xarray to leverage these capabilities, including the ability to store and analyze datasets far too large fit onto disk (particularly :ref:`in combination with dask `). -.. warning:: - - Zarr support is still an experimental feature. Please report any bugs or - unexepected behavior via github issues. - Xarray can't open just any zarr dataset, because xarray requires special metadata (attributes) describing the dataset dimensions and coordinates. At this time, xarray can only open zarr datasets that have been written by @@ -890,17 +885,44 @@ Cloud Storage Buckets It is possible to read and write xarray datasets directly from / to cloud storage buckets using zarr. This example uses the `gcsfs`_ package to provide -a ``MutableMapping`` interface to `Google Cloud Storage`_, which we can then -pass to xarray:: +an interface to `Google Cloud Storage`_. + +From v0.16.2: general `fsspec`_ URLs are parsed and the store set up for you +automatically when reading, such that you can open a dataset in a single +call. You should include any arguments to the storage backend as the +key ``storage_options``, part of ``backend_kwargs``. + +.. code:: python + + ds_gcs = xr.open_dataset( + "gcs:///path.zarr", + backend_kwargs={ + "storage_options": {"project": "", "token": None} + }, + engine="zarr", + ) + + +This also works with ``open_mfdataset``, allowing you to pass a list of paths or +a URL to be interpreted as a glob string. + +For older versions, and for writing, you must explicitly set up a ``MutableMapping`` +instance and pass this, as follows: + +.. code:: python import gcsfs - fs = gcsfs.GCSFileSystem(project='', token=None) - gcsmap = gcsfs.mapping.GCSMap('', gcs=fs, check=True, create=False) + + fs = gcsfs.GCSFileSystem(project="", token=None) + gcsmap = gcsfs.mapping.GCSMap("", gcs=fs, check=True, create=False) # write to the bucket ds.to_zarr(store=gcsmap) # read it back ds_gcs = xr.open_zarr(gcsmap) +(or use the utility function ``fsspec.get_mapper()``). + +.. _fsspec: https://filesystem-spec.readthedocs.io/en/latest/ .. _Zarr: http://zarr.readthedocs.io/ .. _Amazon S3: https://aws.amazon.com/s3/ .. _Google Cloud Storage: https://cloud.google.com/storage/ @@ -932,6 +954,8 @@ For example: Not all native zarr compression and filtering options have been tested with xarray. +.. _io.zarr.consolidated_metadata: + Consolidated Metadata ~~~~~~~~~~~~~~~~~~~~~ @@ -939,27 +963,27 @@ Xarray needs to read all of the zarr metadata when it opens a dataset. In some storage mediums, such as with cloud object storage (e.g. amazon S3), this can introduce significant overhead, because two separate HTTP calls to the object store must be made for each variable in the dataset. -With version 2.3, zarr will support a feature called *consolidated metadata*, -which allows all metadata for the entire dataset to be stored with a single -key (by default called ``.zmetadata``). This can drastically speed up -opening the store. (For more information on this feature, consult the +As of Xarray version 0.18, Xarray by default uses a feature called +*consolidated metadata*, storing all metadata for the entire dataset with a +single key (by default called ``.zmetadata``). This typically drastically speeds +up opening the store. (For more information on this feature, consult the `zarr docs `_.) -If you have zarr version 2.3 or greater, xarray can write and read stores -with consolidated metadata. To write consolidated metadata, pass the -``consolidated=True`` option to the -:py:attr:`Dataset.to_zarr` method:: - - ds.to_zarr('foo.zarr', consolidated=True) +By default, Xarray writes consolidated metadata and attempts to read stores +with consolidated metadata, falling back to use non-consolidated metadata for +reads. Because this fall-back option is so much slower, Xarray issues a +``RuntimeWarning`` with guidance when reading with consolidated metadata fails: -To read a consolidated store, pass the ``consolidated=True`` option to -:py:func:`open_zarr`:: + Failed to open Zarr store with consolidated metadata, falling back to try + reading non-consolidated metadata. This is typically much slower for + opening a dataset. To silence this warning, consider: - ds = xr.open_zarr('foo.zarr', consolidated=True) - -Xarray can't perform consolidation on pre-existing zarr datasets. This should -be done directly from zarr, as described in the -`zarr docs `_. + 1. Consolidating metadata in this existing store with + :py:func:`zarr.consolidate_metadata`. + 2. Explicitly setting ``consolidated=False``, to avoid trying to read + consolidate metadata. + 3. Explicitly setting ``consolidated=True``, to raise an error in this case + instead of falling back to try reading non-consolidated metadata. .. _io.zarr.appending: @@ -1040,7 +1064,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata ds = xr.Dataset({"foo": ("x", dummies)}) path = "path/to/directory.zarr" # Now we write the metadata without computing any array values - ds.to_zarr(path, compute=False, consolidated=True) + ds.to_zarr(path, compute=False) Now, a Zarr store with the correct variable shapes and attributes exists that can be filled out by subsequent calls to ``to_zarr``. The ``region`` provides a @@ -1050,7 +1074,7 @@ data should be written (in index space, not coordinate space), e.g., .. ipython:: python # For convenience, we'll slice a single dataset, but in the real use-case - # we would create them separately, possibly even from separate processes. + # we would create them separately possibly even from separate processes. ds = xr.Dataset({"foo": ("x", np.arange(30))}) ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)}) ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)}) @@ -1106,7 +1130,7 @@ We recommend installing PyNIO via conda:: conda install -c conda-forge pynio - .. note:: +.. warning:: PyNIO is no longer actively maintained and conflicts with netcdf4 > 1.5.3. The PyNIO backend may be moved outside of xarray in the future. diff --git a/doc/pandas.rst b/doc/user-guide/pandas.rst similarity index 100% rename from doc/pandas.rst rename to doc/user-guide/pandas.rst diff --git a/doc/plotting.rst b/doc/user-guide/plotting.rst similarity index 93% rename from doc/plotting.rst rename to doc/user-guide/plotting.rst index 3699f794ae8..f1c76b21488 100644 --- a/doc/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -227,7 +227,7 @@ from the time and assign it as a non-dimension coordinate: :okwarning: decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta("1d") - air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day)) + air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day.data)) air1d_multi To use ``'decimal_day'`` as x coordinate it must be explicitly specified: @@ -411,6 +411,37 @@ produce plots with nonuniform coordinates. @savefig plotting_nonuniform_coords.png width=4in b.plot() +==================== + Other types of plot +==================== + +There are several other options for plotting 2D data. + +Contour plot using :py:meth:`DataArray.plot.contour()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contour.png width=4in + air2d.plot.contour() + +Filled contour plot using :py:meth:`DataArray.plot.contourf()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contourf.png width=4in + air2d.plot.contourf() + +Surface plot using :py:meth:`DataArray.plot.surface()` + +.. ipython:: python + :okwarning: + + @savefig plotting_surface.png width=4in + # transpose just to make the example look a bit nicer + air2d.T.plot.surface() + ==================== Calling Matplotlib ==================== @@ -715,6 +746,9 @@ Consider this dataset ds +Scatter +~~~~~~~ + Suppose we want to scatter ``A`` against ``B`` .. ipython:: python @@ -762,6 +796,47 @@ Faceting is also possible For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. +Quiver +~~~~~~ + +Visualizing vector fields is supported with quiver plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_quiver.png + ds.isel(w=1, z=1).plot.quiver(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the arrow vectors. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_quiver.png + ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) + +``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. + +Streamplot +~~~~~~~~~~ + +Visualizing vector fields is also supported with streamline plots: + +.. ipython:: python + :okwarning: + + @savefig ds_simple_streamplot.png + ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B") + + +where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible: + +.. ipython:: python + :okwarning: + + @savefig ds_facet_streamplot.png + ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z") .. _plot-maps: diff --git a/doc/reshaping.rst b/doc/user-guide/reshaping.rst similarity index 100% rename from doc/reshaping.rst rename to doc/user-guide/reshaping.rst diff --git a/doc/terminology.rst b/doc/user-guide/terminology.rst similarity index 98% rename from doc/terminology.rst rename to doc/user-guide/terminology.rst index 3cfc211593f..1876058323e 100644 --- a/doc/terminology.rst +++ b/doc/user-guide/terminology.rst @@ -79,7 +79,7 @@ complete examples, please consult the relevant documentation.* example, multidimensional coordinates are often used in geoscience datasets when :doc:`the data's physical coordinates (such as latitude and longitude) differ from their logical coordinates - `. However, non-dimension coordinates + <../examples/multidimensional-coords>`. However, non-dimension coordinates are not indexed, and any operation on non-dimension coordinates that leverages indexing will fail. Printing ``arr.coords`` will print all of ``arr``'s coordinate names, with the corresponding dimension(s) in diff --git a/doc/time-series.rst b/doc/user-guide/time-series.rst similarity index 99% rename from doc/time-series.rst rename to doc/user-guide/time-series.rst index 96a2edc0ea5..f9d341ff25d 100644 --- a/doc/time-series.rst +++ b/doc/user-guide/time-series.rst @@ -224,4 +224,4 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. For more examples of using grouped operations on a time dimension, see -:doc:`examples/weather-data`. +:doc:`../examples/weather-data`. diff --git a/doc/weather-climate.rst b/doc/user-guide/weather-climate.rst similarity index 88% rename from doc/weather-climate.rst rename to doc/user-guide/weather-climate.rst index db612d74859..057bd3d0d54 100644 --- a/doc/weather-climate.rst +++ b/doc/user-guide/weather-climate.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _weather-climate: Weather and climate data @@ -8,10 +10,40 @@ Weather and climate data import xarray as xr -``xarray`` can leverage metadata that follows the `Climate and Forecast (CF) conventions`_ if present. Examples include automatic labelling of plots with descriptive names and units if proper metadata is present (see :ref:`plotting`) and support for non-standard calendars used in climate science through the ``cftime`` module (see :ref:`CFTimeIndex`). There are also a number of geosciences-focused projects that build on xarray (see :ref:`related-projects`). +``xarray`` can leverage metadata that follows the `Climate and Forecast (CF) conventions`_ if present. Examples include automatic labelling of plots with descriptive names and units if proper metadata is present (see :ref:`plotting`) and support for non-standard calendars used in climate science through the ``cftime`` module (see :ref:`CFTimeIndex`). There are also a number of geosciences-focused projects that build on xarray (see :ref:`ecosystem`). .. _Climate and Forecast (CF) conventions: http://cfconventions.org +.. _cf_variables: + +Related Variables +----------------- + +Several CF variable attributes contain lists of other variables +associated with the variable with the attribute. A few of these are +now parsed by XArray, with the attribute value popped to encoding on +read and the variables in that value interpreted as non-dimension +coordinates: + +- ``coordinates`` +- ``bounds`` +- ``grid_mapping`` +- ``climatology`` +- ``geometry`` +- ``node_coordinates`` +- ``node_count`` +- ``part_node_count`` +- ``interior_ring`` +- ``cell_measures`` +- ``formula_terms`` + +This decoding is controlled by the ``decode_coords`` kwarg to +:py:func:`open_dataset` and :py:func:`open_mfdataset`. + +The CF attribute ``ancillary_variables`` was not included in the list +due to the variables listed there being associated primarily with the +variable with the attribute, rather than with the dimensions. + .. _metpy_accessor: CF-compliant coordinate variables diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 158ddf3350a..16555574649 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,50 +15,567 @@ What's New np.random.seed(123456) -.. _whats-new.0.16.3: +.. _whats-new.0.19.1: -v0.17.0 (unreleased) +v0.19.1 (unreleased) +--------------------- + +New Features +~~~~~~~~~~~~ +- Add a option to disable the use of ``bottleneck`` (:pull:`5560`) + By `Justus Magin `_. +- Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`). + By `Pushkar Kopparla `_. + + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`). + By `Deepak Cherian `_. +- Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`). + By `Benoit Bovy `_. +- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) + By `Jimmy Westling `_. + +.. _whats-new.0.19.0: + +v0.19.0 (23 July 2021) +---------------------- + +This release brings improvements to plotting of categorical data, the ability to specify how attributes +are combined in xarray operations, a new high-level :py:func:`unify_chunks` function, as well as various +deprecations, bug fixes, and minor improvements. + + +Many thanks to the 29 contributors to this release!: + +Andrew Williams, Augustus, Aureliana Barghini, Benoit Bovy, crusaderky, Deepak Cherian, ellesmith88, +Elliott Sales de Andrade, Giacomo Caria, github-actions[bot], Illviljan, Joeperdefloep, joooeey, Julia Kent, +Julius Busecke, keewis, Mathias Hauser, Matthias Göbel, Mattia Almansi, Maximilian Roos, Peter Andreas Entschev, +Ray Bell, Sander, Santiago Soler, Sebastian, Spencer Clark, Stephan Hoyer, Thomas Hirtz, Thomas Nicholas. + +New Features +~~~~~~~~~~~~ +- Allow passing argument ``missing_dims`` to :py:meth:`Variable.transpose` and :py:meth:`Dataset.transpose` + (:issue:`5550`, :pull:`5586`) + By `Giacomo Caria `_. +- Allow passing a dictionary as coords to a :py:class:`DataArray` (:issue:`5527`, + reverts :pull:`1539`, which had deprecated this due to python's inconsistent ordering in earlier versions). + By `Sander van Rijn `_. +- Added :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` (:issue:`5454`, :pull:`5475`). + By `Deepak Cherian `_. +- Xarray now uses consolidated metadata by default when writing and reading Zarr + stores (:issue:`5251`). + By `Stephan Hoyer `_. +- New top-level function :py:func:`unify_chunks`. + By `Mattia Almansi `_. +- Allow assigning values to a subset of a dataset using positional or label-based + indexing (:issue:`3015`, :pull:`5362`). + By `Matthias Göbel `_. +- Attempting to reduce a weighted object over missing dimensions now raises an error (:pull:`5362`). + By `Mattia Almansi `_. +- Add ``.sum`` to :py:meth:`~xarray.DataArray.rolling_exp` and + :py:meth:`~xarray.Dataset.rolling_exp` for exponentially weighted rolling + sums. These require numbagg 0.2.1; + (:pull:`5178`). + By `Maximilian Roos `_. +- :py:func:`xarray.cov` and :py:func:`xarray.corr` now lazily check for missing + values if inputs are dask arrays (:issue:`4804`, :pull:`5284`). + By `Andrew Williams `_. +- Attempting to ``concat`` list of elements that are not all ``Dataset`` or all ``DataArray`` now raises an error (:issue:`5051`, :pull:`5425`). + By `Thomas Hirtz `_. +- allow passing a function to ``combine_attrs`` (:pull:`4896`). + By `Justus Magin `_. +- Allow plotting categorical data (:pull:`5464`). + By `Jimmy Westling `_. +- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None`` + (:issue:`5510`). + By `Elle Smith `_. +- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`). + By `Tom Nicholas `_. +- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`). + By `Tom Nicholas `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The default ``mode`` for :py:meth:`Dataset.to_zarr` when ``region`` is set + has changed to the new ``mode="r+"``, which only allows for overriding + pre-existing array values. This is a safer default than the prior ``mode="a"``, + and allows for higher performance writes (:pull:`5252`). + By `Stephan Hoyer `_. +- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead + of `datasets` so anyone calling this method using a named parameter will need to update + the name accordingly (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. + +Deprecations +~~~~~~~~~~~~ + +- Removed the deprecated ``dim`` kwarg to :py:func:`DataArray.integrate` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.rolling` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.coarsen` (:pull:`5630`) +- Completed deprecation of passing an ``xarray.DataArray`` to :py:func:`Variable` - will now raise a ``TypeError`` (:pull:`5630`) + +Bug fixes +~~~~~~~~~ +- Fix a minor incompatibility between partial datetime string indexing with a + :py:class:`CFTimeIndex` and upcoming pandas version 1.3.0 (:issue:`5356`, + :pull:`5359`). + By `Spencer Clark `_. +- Fix 1-level multi-index incorrectly converted to single index (:issue:`5384`, + :pull:`5385`). + By `Benoit Bovy `_. +- Don't cast a duck array in a coordinate to :py:class:`numpy.ndarray` in + :py:meth:`DataArray.differentiate` (:pull:`5408`) + By `Justus Magin `_. +- Fix the ``repr`` of :py:class:`Variable` objects with ``display_expand_data=True`` + (:pull:`5406`) + By `Justus Magin `_. +- Plotting a pcolormesh with ``xscale="log"`` and/or ``yscale="log"`` works as + expected after improving the way the interval breaks are generated (:issue:`5333`). + By `Santiago Soler `_ +- :py:func:`combine_by_coords` can now handle combining a list of unnamed + ``DataArray`` as input (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ +- Run CI on the first & last python versions supported only; currently 3.7 & 3.9. + (:pull:`5433`) + By `Maximilian Roos `_. +- Publish test results & timings on each PR. + (:pull:`5537`) + By `Maximilian Roos `_. +- Explicit indexes refactor: add a ``xarray.Index.query()`` method in which + one may eventually provide a custom implementation of label-based data + selection (not ready yet for public use). Also refactor the internal, + pandas-specific implementation into ``PandasIndex.query()`` and + ``PandasMultiIndex.query()`` (:pull:`5322`). + By `Benoit Bovy `_. + +.. _whats-new.0.18.2: + +v0.18.2 (19 May 2021) +--------------------- + +This release reverts a regression in xarray's unstacking of dask-backed arrays. + +.. _whats-new.0.18.1: + +v0.18.1 (18 May 2021) +--------------------- + +This release is intended as a small patch release to be compatible with the new +2021.5.0 ``dask.distributed`` release. It also includes a new +``drop_duplicates`` method, some documentation improvements, the beginnings of +our internal Index refactoring, and some bug fixes. + +Thank you to all 16 contributors! + +Anderson Banihirwe, Andrew, Benoit Bovy, Brewster Malevich, Giacomo Caria, +Illviljan, James Bourbeau, Keewis, Maximilian Roos, Ravin Kumar, Stephan Hoyer, +Thomas Nicholas, Tom Nicholas, Zachary Moon. + +New Features +~~~~~~~~~~~~ +- Implement :py:meth:`DataArray.drop_duplicates` + to remove duplicate dimension values (:pull:`5239`). + By `Andrew Huang `_. +- Allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of + :py:func:`apply_ufunc` (:pull:`5041`) + By `Justus Magin `_. +- :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes, + such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`). + By `Jimmy Westling `_. +- Raise more informative error when decoding time variables with invalid reference dates. + (:issue:`5199`, :pull:`5288`). By `Giacomo Caria `_. + + +Bug fixes +~~~~~~~~~ +- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying + an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in + 0.18.0. + By `Stephan Hoyer `_ + +Documentation +~~~~~~~~~~~~~ +- Clean up and enhance docstrings for the :py:class:`DataArray.plot` and ``Dataset.plot.*`` + families of methods (:pull:`5285`). + By `Zach Moon `_. + +- Explanation of deprecation cycles and how to implement them added to contributors + guide. (:pull:`5289`) + By `Tom Nicholas `_. + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Explicit indexes refactor: add an ``xarray.Index`` base class and + ``Dataset.xindexes`` / ``DataArray.xindexes`` properties. Also rename + ``PandasIndexAdapter`` to ``PandasIndex``, which now inherits from + ``xarray.Index`` (:pull:`5102`). + By `Benoit Bovy `_. +- Replace ``SortedKeysDict`` with python's ``dict``, given dicts are now ordered. + By `Maximilian Roos `_. +- Updated the release guide for developers. Now accounts for actions that are automated via github + actions. (:pull:`5274`). + By `Tom Nicholas `_. + +.. _whats-new.0.18.0: + +v0.18.0 (6 May 2021) -------------------- +This release brings a few important performance improvements, a wide range of +usability upgrades, lots of bug fixes, and some new features. These include +a plugin API to add backend engines, a new theme for the documentation, +curve fitting methods, and several new plotting functions. + +Many thanks to the 38 contributors to this release: Aaron Spring, Alessandro Amici, +Alex Marandon, Alistair Miles, Ana Paula Krelling, Anderson Banihirwe, Aureliana Barghini, +Baudouin Raoult, Benoit Bovy, Blair Bonnett, David Trémouilles, Deepak Cherian, +Gabriel Medeiros Abrahão, Giacomo Caria, Hauke Schulz, Illviljan, Mathias Hauser, Matthias Bussonnier, +Mattia Almansi, Maximilian Roos, Ray Bell, Richard Kleijn, Ryan Abernathey, Sam Levang, Spencer Clark, +Spencer Jones, Tammas Loughran, Tobias Kölling, Todd, Tom Nicholas, Tom White, Victor Negîrneac, +Xianxiang Li, Zeb Nicholls, crusaderky, dschwoerer, johnomotani, keewis + + +New Features +~~~~~~~~~~~~ + +- apply ``combine_attrs`` on data variables and coordinate variables when concatenating + and merging datasets and dataarrays (:pull:`4902`). + By `Justus Magin `_. +- Add :py:meth:`Dataset.to_pandas` (:pull:`5247`) + By `Giacomo Caria `_. +- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make + surface plots (:issue:`2235` :issue:`5084` :pull:`5101`). + By `John Omotani `_. +- Allow passing multiple arrays to :py:meth:`Dataset.__setitem__` (:pull:`5216`). + By `Giacomo Caria `_. +- Add 'cumulative' option to :py:meth:`Dataset.integrate` and + :py:meth:`DataArray.integrate` so that result is a cumulative integral, like + :py:func:`scipy.integrate.cumulative_trapezoidal` (:pull:`5153`). + By `John Omotani `_. +- Add ``safe_chunks`` option to :py:meth:`Dataset.to_zarr` which allows overriding + checks made to ensure Dask and Zarr chunk compatibility (:issue:`5056`). + By `Ryan Abernathey `_ +- Add :py:meth:`Dataset.query` and :py:meth:`DataArray.query` which enable indexing + of datasets and data arrays by evaluating query expressions against the values of the + data variables (:pull:`4984`). + By `Alistair Miles `_. +- Allow passing ``combine_attrs`` to :py:meth:`Dataset.merge` (:pull:`4895`). + By `Justus Magin `_. +- Support for `dask.graph_manipulation + `_ (requires dask >=2021.3) + By `Guido Imperiale `_ +- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset` + variables (:pull:`5003`). + By `John Omotani `_. +- Many of the arguments for the :py:attr:`DataArray.str` methods now support + providing an array-like input. In this case, the array provided to the + arguments is broadcast against the original array and applied elementwise. +- :py:attr:`DataArray.str` now supports ``+``, ``*``, and ``%`` operators. These + behave the same as they do for :py:class:`str`, except that they follow + array broadcasting rules. +- A large number of new :py:attr:`DataArray.str` methods were implemented, + :py:meth:`DataArray.str.casefold`, :py:meth:`DataArray.str.cat`, + :py:meth:`DataArray.str.extract`, :py:meth:`DataArray.str.extractall`, + :py:meth:`DataArray.str.findall`, :py:meth:`DataArray.str.format`, + :py:meth:`DataArray.str.get_dummies`, :py:meth:`DataArray.str.islower`, + :py:meth:`DataArray.str.join`, :py:meth:`DataArray.str.normalize`, + :py:meth:`DataArray.str.partition`, :py:meth:`DataArray.str.rpartition`, + :py:meth:`DataArray.str.rsplit`, and :py:meth:`DataArray.str.split`. + A number of these methods allow for splitting or joining the strings in an + array. (:issue:`4622`) + By `Todd Jennings `_ +- Thanks to the new pluggable backend infrastructure external packages may now + use the ``xarray.backends`` entry point to register additional engines to be used in + :py:func:`open_dataset`, see the documentation in :ref:`add_a_backend` + (:issue:`4309`, :issue:`4803`, :pull:`4989`, :pull:`4810` and many others). + The backend refactor has been sponsored with the "Essential Open Source Software for Science" + grant from the `Chan Zuckerberg Initiative `_ and + developed by `B-Open `_. + By `Aureliana Barghini `_ and `Alessandro Amici `_. +- :py:attr:`~core.accessor_dt.DatetimeAccessor.date` added (:issue:`4983`, :pull:`4994`). + By `Hauke Schulz `_. +- Implement ``__getitem__`` for both :py:class:`~core.groupby.DatasetGroupBy` and + :py:class:`~core.groupby.DataArrayGroupBy`, inspired by pandas' + :py:meth:`~pandas.core.groupby.GroupBy.get_group`. + By `Deepak Cherian `_. +- Switch the tutorial functions to use `pooch `_ + (which is now a optional dependency) and add :py:func:`tutorial.open_rasterio` as a + way to open example rasterio files (:issue:`3986`, :pull:`4102`, :pull:`5074`). + By `Justus Magin `_. +- Add typing information to unary and binary arithmetic operators operating on + :py:class:`Dataset`, :py:class:`DataArray`, :py:class:`Variable`, + :py:class:`~core.groupby.DatasetGroupBy` or + :py:class:`~core.groupby.DataArrayGroupBy` (:pull:`4904`). + By `Richard Kleijn `_. +- Add a ``combine_attrs`` parameter to :py:func:`open_mfdataset` (:pull:`4971`). + By `Justus Magin `_. +- Enable passing arrays with a subset of dimensions to + :py:meth:`DataArray.clip` & :py:meth:`Dataset.clip`; these methods now use + :py:func:`xarray.apply_ufunc`; (:pull:`5184`). + By `Maximilian Roos `_. +- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). + By `Baudouin Raoult `_. +- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`) + By `Sam Levang `_. +- Add options to control expand/collapse of sections in display of Dataset and + DataArray. The function :py:func:`set_options` now takes keyword aguments + ``display_expand_attrs``, ``display_expand_coords``, ``display_expand_data``, + ``display_expand_data_vars``, all of which can be one of ``True`` to always + expand, ``False`` to always collapse, or ``default`` to expand unless over a + pre-defined limit (:pull:`5126`). + By `Tom White `_. +- Significant speedups in :py:meth:`Dataset.interp` and :py:meth:`DataArray.interp`. + (:issue:`4739`, :pull:`4740`). + By `Deepak Cherian `_. +- Prevent passing `concat_dim` to :py:func:`xarray.open_mfdataset` when + `combine='by_coords'` is specified, which should never have been possible (as + :py:func:`xarray.combine_by_coords` has no `concat_dim` argument to pass to). + Also removes unneeded internal reordering of datasets in + :py:func:`xarray.open_mfdataset` when `combine='by_coords'` is specified. + Fixes (:issue:`5230`). + By `Tom Nicholas `_. +- Implement ``__setitem__`` for ``xarray.core.indexing.DaskIndexingAdapter`` if + dask version supports item assignment. (:issue:`5171`, :pull:`5174`) + By `Tammas Loughran `_. + +Breaking changes +~~~~~~~~~~~~~~~~ +- The minimum versions of some dependencies were changed: + + ============ ====== ==== + Package Old New + ============ ====== ==== + boto3 1.12 1.13 + cftime 1.0 1.1 + dask 2.11 2.15 + distributed 2.11 2.15 + matplotlib 3.1 3.2 + numba 0.48 0.49 + ============ ====== ==== + +- :py:func:`open_dataset` and :py:func:`open_dataarray` now accept only the first argument + as positional, all others need to be passed are keyword arguments. This is part of the + refactor to support external backends (:issue:`4309`, :pull:`4989`). + By `Alessandro Amici `_. +- Functions that are identities for 0d data return the unchanged data + if axis is empty. This ensures that Datasets where some variables do + not have the averaged dimensions are not accidentially changed + (:issue:`4885`, :pull:`5207`). + By `David Schwörer `_. +- :py:attr:`DataArray.coarsen` and :py:attr:`Dataset.coarsen` no longer support passing ``keep_attrs`` + via its constructor. Pass ``keep_attrs`` via the applied function, i.e. use + ``ds.coarsen(...).mean(keep_attrs=False)`` instead of ``ds.coarsen(..., keep_attrs=False).mean()``. + Further, coarsen now keeps attributes per default (:pull:`5227`). + By `Mathias Hauser `_. +- switch the default of the :py:func:`merge` ``combine_attrs`` parameter to + ``"override"``. This will keep the current behavior for merging the ``attrs`` of + variables but stop dropping the ``attrs`` of the main objects (:pull:`4902`). + By `Justus Magin `_. + +Deprecations +~~~~~~~~~~~~ + +- Warn when passing `concat_dim` to :py:func:`xarray.open_mfdataset` when + `combine='by_coords'` is specified, which should never have been possible (as + :py:func:`xarray.combine_by_coords` has no `concat_dim` argument to pass to). + Also removes unneeded internal reordering of datasets in + :py:func:`xarray.open_mfdataset` when `combine='by_coords'` is specified. + Fixes (:issue:`5230`), via (:pull:`5231`, :pull:`5255`). + By `Tom Nicholas `_. +- The `lock` keyword argument to :py:func:`open_dataset` and :py:func:`open_dataarray` is now + a backend specific option. It will give a warning if passed to a backend that doesn't support it + instead of being silently ignored. From the next version it will raise an error. + This is part of the refactor to support external backends (:issue:`5073`). + By `Tom Nicholas `_ and `Alessandro Amici `_. + + +Bug fixes +~~~~~~~~~ +- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions. + (:issue:`2699`). + By `Deepak Cherian `_. +- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is + 2d (:issue:`5097`, :pull:`5099`). + By `John Omotani `_. +- Ensure standard calendar times encoded with large values (i.e. greater than + approximately 292 years), can be decoded correctly without silently overflowing + (:pull:`5050`). This was a regression in xarray 0.17.0. + By `Zeb Nicholls `_. +- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`). + By `Victor Negîrneac `_. +- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`). + By `Justus Magin `_. +- Decode values as signed if attribute `_Unsigned = "false"` (:issue:`4954`) + By `Tobias Kölling `_. +- Keep coords attributes when interpolating when the indexer is not a Variable. (:issue:`4239`, :issue:`4839` :pull:`5031`) + By `Jimmy Westling `_. +- Ensure standard calendar dates encoded with a calendar attribute with some or + all uppercase letters can be decoded or encoded to or from + ``np.datetime64[ns]`` dates with or without ``cftime`` installed + (:issue:`5093`, :pull:`5180`). + By `Spencer Clark `_. +- Warn on passing ``keep_attrs`` to ``resample`` and ``rolling_exp`` as they are ignored, pass ``keep_attrs`` + to the applied function instead (:pull:`5265`). + By `Mathias Hauser `_. + +Documentation +~~~~~~~~~~~~~ +- New section on :ref:`add_a_backend` in the "Internals" chapter aimed to backend developers + (:issue:`4803`, :pull:`4810`). + By `Aureliana Barghini `_. +- Add :py:meth:`Dataset.polyfit` and :py:meth:`DataArray.polyfit` under "See also" in + the docstrings of :py:meth:`Dataset.polyfit` and :py:meth:`DataArray.polyfit` + (:issue:`5016`, :pull:`5020`). + By `Aaron Spring `_. +- New sphinx theme & rearrangement of the docs (:pull:`4835`). + By `Anderson Banihirwe `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Enable displaying mypy error codes and ignore only specific error codes using + ``# type: ignore[error-code]`` (:pull:`5096`). + By `Mathias Hauser `_. +- Replace uses of ``raises_regex`` with the more standard + ``pytest.raises(Exception, match="foo")``; + (:pull:`5188`), (:pull:`5191`). + By `Maximilian Roos `_. + +.. _whats-new.0.17.0: + +v0.17.0 (24 Feb 2021) +--------------------- + +This release brings a few important performance improvements, a wide range of +usability upgrades, lots of bug fixes, and some new features. These include +better ``cftime`` support, a new quiver plot, better ``unstack`` performance, +more efficient memory use in rolling operations, and some python packaging +improvements. We also have a few documentation improvements (and more planned!). + +Many thanks to the 36 contributors to this release: Alessandro Amici, Anderson +Banihirwe, Aureliana Barghini, Ayrton Bourn, Benjamin Bean, Blair Bonnett, Chun +Ho Chow, DWesl, Daniel Mesejo-León, Deepak Cherian, Eric Keenan, Illviljan, Jens +Hedegaard Nielsen, Jody Klymak, Julien Seguinot, Julius Busecke, Kai Mühlbauer, +Leif Denby, Martin Durant, Mathias Hauser, Maximilian Roos, Michael Mann, Ray +Bell, RichardScottOZ, Spencer Clark, Tim Gates, Tom Nicholas, Yunus Sevinchan, +alexamici, aurghs, crusaderky, dcherian, ghislainp, keewis, rhkleijn + Breaking changes ~~~~~~~~~~~~~~~~ - xarray no longer supports python 3.6 - The minimum versions of some other dependencies were changed: + The minimum version policy was changed to also apply to projects with irregular + releases. As a result, the minimum versions of some dependencies have changed: + ============ ====== ==== Package Old New ============ ====== ==== Python 3.6 3.7 setuptools 38.4 40.4 + numpy 1.15 1.17 + pandas 0.25 1.0 + dask 2.9 2.11 + distributed 2.9 2.11 + bottleneck 1.2 1.3 + h5netcdf 0.7 0.8 + iris 2.2 2.4 + netcdf4 1.4 1.5 + pseudonetcdf 3.0 3.1 + rasterio 1.0 1.1 + scipy 1.3 1.4 + seaborn 0.9 0.10 + zarr 2.3 2.4 ============ ====== ==== - (:issue:`4688`, :pull:`4720`) - By `Justus Magin `_. + (:issue:`4688`, :pull:`4720`, :pull:`4907`, :pull:`4942`) - As a result of :pull:`4684` the default units encoding for datetime-like values (``np.datetime64[ns]`` or ``cftime.datetime``) will now always be set such that ``int64`` values can be used. In the past, no units finer than "seconds" were chosen, which would sometimes mean that ``float64`` values were required, which would lead to inaccurate I/O round-trips. -- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`). - By `Aureliana Barghini `_. +- Variables referred to in attributes like ``bounds`` and ``grid_mapping`` + can be set as coordinate variables. These attributes are moved to + :py:attr:`DataArray.encoding` from :py:attr:`DataArray.attrs`. This behaviour + is controlled by the ``decode_coords`` kwarg to :py:func:`open_dataset` and + :py:func:`open_mfdataset`. The full list of decoded attributes is in + :ref:`weather-climate` (:pull:`2844`, :issue:`3689`) +- As a result of :pull:`4911` the output from calling :py:meth:`DataArray.sum` + or :py:meth:`DataArray.prod` on an integer array with ``skipna=True`` and a + non-None value for ``min_count`` will now be a float array rather than an + integer array. Deprecations ~~~~~~~~~~~~ - ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in favour of a ``coord`` argument, for consistency with :py:meth:`Dataset.integrate`. - For now using ``dim`` issues a ``FutureWarning``. By `Tom Nicholas `_. - + For now using ``dim`` issues a ``FutureWarning``. It will be removed in + version 0.19.0 (:pull:`3993`). + By `Tom Nicholas `_. +- Deprecated ``autoclose`` kwargs from :py:func:`open_dataset` are removed (:pull:`4725`). + By `Aureliana Barghini `_. +- the return value of :py:meth:`Dataset.update` is being deprecated to make it work more + like :py:meth:`dict.update`. It will be removed in version 0.19.0 (:pull:`4932`). + By `Justus Magin `_. New Features ~~~~~~~~~~~~ +- :py:meth:`~xarray.cftime_range` and :py:meth:`DataArray.resample` now support + millisecond (``"L"`` or ``"ms"``) and microsecond (``"U"`` or ``"us"``) frequencies + for ``cftime.datetime`` coordinates (:issue:`4097`, :pull:`4758`). + By `Spencer Clark `_. - Significantly higher ``unstack`` performance on numpy-backed arrays which - contain missing values; 8x faster in our benchmark, and 2x faster than pandas. - (:pull:`4746`); + contain missing values; 8x faster than previous versions in our benchmark, and + now 2x faster than pandas (:pull:`4746`). By `Maximilian Roos `_. - -- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. +- Add :py:meth:`Dataset.plot.quiver` for quiver plots with :py:class:`Dataset` variables. + By `Deepak Cherian `_. +- Add ``"drop_conflicts"`` to the strategies supported by the ``combine_attrs`` kwarg + (:issue:`4749`, :pull:`4827`). + By `Justus Magin `_. +- Allow installing from git archives (:pull:`4897`). + By `Justus Magin `_. +- :py:class:`~core.rolling.DataArrayCoarsen` and :py:class:`~core.rolling.DatasetCoarsen` + now implement a ``reduce`` method, enabling coarsening operations with custom + reduction functions (:issue:`3741`, :pull:`4939`). + By `Spencer Clark `_. +- Most rolling operations use significantly less memory. (:issue:`4325`). By `Deepak Cherian `_. +- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` + (:issue:`4658`, :pull:`4819`). + By `Daniel Mesejo `_. +- Xarray now leverages updates as of cftime version 1.4.1, which enable exact I/O + roundtripping of ``cftime.datetime`` objects (:pull:`4758`). + By `Spencer Clark `_. +- :py:func:`open_dataset` and :py:func:`open_mfdataset` now accept ``fsspec`` URLs + (including globs for the latter) for ``engine="zarr"``, and so allow reading from + many remote and other file systems (:pull:`4461`) + By `Martin Durant `_ - :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims in the form of kwargs as well as a dict, like most similar methods. By `Maximilian Roos `_. @@ -69,47 +586,85 @@ New Features Bug fixes ~~~~~~~~~ -- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` do not trigger computations anymore if :py:meth:`Dataset.weighted` or :py:meth:`DataArray.weighted` are applied (:issue:`4625`, :pull:`4668`). By `Julius Busecke `_. -- :py:func:`merge` with ``combine_attrs='override'`` makes a copy of the attrs (:issue:`4627`). -- By default, when possible, xarray will now always use values of type ``int64`` when encoding - and decoding ``numpy.datetime64[ns]`` datetimes. This ensures that maximum - precision and accuracy are maintained in the round-tripping process - (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard calendar - dates with time units of nanoseconds (:pull:`4400`). By `Spencer Clark - `_ and `Mark Harfouche `_. +- Use specific type checks in ``xarray.core.variable.as_compatible_data`` instead of + blanket access to ``values`` attribute (:issue:`2097`) + By `Yunus Sevinchan `_. +- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` do not trigger + computations anymore if :py:meth:`Dataset.weighted` or + :py:meth:`DataArray.weighted` are applied (:issue:`4625`, :pull:`4668`). By + `Julius Busecke `_. +- :py:func:`merge` with ``combine_attrs='override'`` makes a copy of the attrs + (:issue:`4627`). +- By default, when possible, xarray will now always use values of + type ``int64`` when encoding and decoding ``numpy.datetime64[ns]`` datetimes. This + ensures that maximum precision and accuracy are maintained in the round-tripping + process (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard + calendar dates with time units of nanoseconds (:pull:`4400`). + By `Spencer Clark `_ and `Mark Harfouche + `_. - :py:meth:`DataArray.astype`, :py:meth:`Dataset.astype` and :py:meth:`Variable.astype` support the ``order`` and ``subok`` parameters again. This fixes a regression introduced in version 0.16.1 (:issue:`4644`, :pull:`4683`). By `Richard Kleijn `_ . - Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`). - By `Anderson Banihirwe `_ + By `Anderson Banihirwe `_. - Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`). By `Justus Magin `_. -- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`). - By `Alessandro Amici `_ +- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` + (:issue:`4733` :pull:`4737`). + By `Alessandro Amici `_. - Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations, e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype - (:issue:`2658` and :issue:`4543`) by `Mathias Hauser `_. -- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling `_. -- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo `_. + (:issue:`2658` and :issue:`4543`). + By `Mathias Hauser `_. +- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). + By `Jimmy Westling `_. +- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). + By `Daniel Mesejo `_. - Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`). By `Justus Magin `_. - Fix regression when decoding a variable with a ``scale_factor`` and ``add_offset`` given - as a list of length one (:issue:`4631`) by `Mathias Hauser `_. + as a list of length one (:issue:`4631`). + By `Mathias Hauser `_. - Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`). By `Julien Seguinot `_. -- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` (:issue:`4658`, :pull:`4819`). By `Daniel Mesejo `_. +- Raise DeprecationWarning when trying to typecast a tuple containing a :py:class:`DataArray`. + User now prompted to first call `.data` on it (:issue:`4483`). + By `Chun Ho Chow `_. +- Ensure that :py:meth:`Dataset.interp` raises ``ValueError`` when interpolating + outside coordinate range and ``bounds_error=True`` (:issue:`4854`, + :pull:`4855`). + By `Leif Denby `_. +- Fix time encoding bug associated with using cftime versions greater than + 1.4.0 with xarray (:issue:`4870`, :pull:`4871`). + By `Spencer Clark `_. +- Stop :py:meth:`DataArray.sum` and :py:meth:`DataArray.prod` computing lazy + arrays when called with a ``min_count`` parameter (:issue:`4898`, :pull:`4911`). + By `Blair Bonnett `_. +- Fix bug preventing the ``min_count`` parameter to :py:meth:`DataArray.sum` and + :py:meth:`DataArray.prod` working correctly when calculating over all axes of + a float64 array (:issue:`4898`, :pull:`4911`). + By `Blair Bonnett `_. +- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`). + By `Kai Mühlbauer `_. +- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex`` + and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`). + By `ghislainp `_. Documentation ~~~~~~~~~~~~~ -- add information about requirements for accessor classes (:issue:`2788`, :pull:`4657`). +- Add information about requirements for accessor classes (:issue:`2788`, :pull:`4657`). By `Justus Magin `_. -- start a list of external I/O integrating with ``xarray`` (:issue:`683`, :pull:`4566`). +- Start a list of external I/O integrating with ``xarray`` (:issue:`683`, :pull:`4566`). By `Justus Magin `_. -- add concat examples and improve combining documentation (:issue:`4620`, :pull:`4645`). +- Add concat examples and improve combining documentation (:issue:`4620`, :pull:`4645`). By `Ray Bell `_ and `Justus Magin `_. +- explicitly mention that :py:meth:`Dataset.update` updates inplace (:issue:`2951`, :pull:`4932`). + By `Justus Magin `_. +- Added docs on vectorized indexing (:pull:`4711`). + By `Eric Keenan `_. Internal Changes ~~~~~~~~~~~~~~~~ @@ -121,32 +676,56 @@ Internal Changes - Run the tests in parallel using pytest-xdist (:pull:`4694`). By `Justus Magin `_ and `Mathias Hauser `_. - +- Use ``pyproject.toml`` instead of the ``setup_requires`` option for + ``setuptools`` (:pull:`4897`). + By `Justus Magin `_. - Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)`` - for clearer error messages. - (:pull:`4752`); + for clearer error messages (:pull:`4752`). By `Maximilian Roos `_. -- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion - in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_. -- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for beckends to specify how to voluntary release - all resources. (:pull:`#4809`), By `Alessandro Amici `_. +- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and + tab completion in IPython (:issue:`4741`, :pull:`4742`). + By `Richard Kleijn `_. +- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends + to specify how to voluntary release all resources. (:pull:`#4809`) + By `Alessandro Amici `_. +- Update type hints to work with numpy v1.20 (:pull:`4878`). + By `Mathias Hauser `_. +- Ensure warnings cannot be turned into exceptions in :py:func:`testing.assert_equal` and + the other ``assert_*`` functions (:pull:`4864`). + By `Mathias Hauser `_. +- Performance improvement when constructing DataArrays. Significantly speeds up + repr for Datasets with large number of variables. + By `Deepak Cherian `_. .. _whats-new.0.16.2: v0.16.2 (30 Nov 2020) --------------------- -This release brings the ability to write to limited regions of ``zarr`` files, open zarr files with :py:func:`open_dataset` and :py:func:`open_mfdataset`, increased support for propagating ``attrs`` using the ``keep_attrs`` flag, as well as numerous bugfixes and documentation improvements. - -Many thanks to the 31 contributors who contributed to this release: -Aaron Spring, Akio Taniguchi, Aleksandar Jelenak, alexamici, Alexandre Poux, Anderson Banihirwe, Andrew Pauling, Ashwin Vishnu, aurghs, Brian Ward, Caleb, crusaderky, Dan Nowacki, darikg, David Brochart, David Huard, Deepak Cherian, Dion Häfner, Gerardo Rivera, Gerrit Holl, Illviljan, inakleinbottle, Jacob Tomlinson, James A. Bednar, jenssss, Joe Hamman, johnomotani, Joris Van den Bossche, Julia Kent, Julius Busecke, Kai Mühlbauer, keewis, Keisuke Fujii, Kyle Cranmer, Luke Volpatti, Mathias Hauser, Maximilian Roos, Michaël Defferrard, Michal Baumgartner, Nick R. Papior, Pascal Bourgault, Peter Hausamann, PGijsbers, Ray Bell, Romain Martinez, rpgoldman, Russell Manser, Sahid Velji, Samnan Rahee, Sander, Spencer Clark, Stephan Hoyer, Thomas Zilio, Tobias Kölling, Tom Augspurger, Wei Ji, Yash Saboo, Zeb Nicholls, +This release brings the ability to write to limited regions of ``zarr`` files, +open zarr files with :py:func:`open_dataset` and :py:func:`open_mfdataset`, +increased support for propagating ``attrs`` using the ``keep_attrs`` flag, as +well as numerous bugfixes and documentation improvements. + +Many thanks to the 31 contributors who contributed to this release: Aaron +Spring, Akio Taniguchi, Aleksandar Jelenak, alexamici, Alexandre Poux, Anderson +Banihirwe, Andrew Pauling, Ashwin Vishnu, aurghs, Brian Ward, Caleb, crusaderky, +Dan Nowacki, darikg, David Brochart, David Huard, Deepak Cherian, Dion Häfner, +Gerardo Rivera, Gerrit Holl, Illviljan, inakleinbottle, Jacob Tomlinson, James +A. Bednar, jenssss, Joe Hamman, johnomotani, Joris Van den Bossche, Julia Kent, +Julius Busecke, Kai Mühlbauer, keewis, Keisuke Fujii, Kyle Cranmer, Luke +Volpatti, Mathias Hauser, Maximilian Roos, Michaël Defferrard, Michal +Baumgartner, Nick R. Papior, Pascal Bourgault, Peter Hausamann, PGijsbers, Ray +Bell, Romain Martinez, rpgoldman, Russell Manser, Sahid Velji, Samnan Rahee, +Sander, Spencer Clark, Stephan Hoyer, Thomas Zilio, Tobias Kölling, Tom +Augspurger, Wei Ji, Yash Saboo, Zeb Nicholls, Deprecations ~~~~~~~~~~~~ - :py:attr:`~core.accessor_dt.DatetimeAccessor.weekofyear` and :py:attr:`~core.accessor_dt.DatetimeAccessor.week` have been deprecated. Use ``DataArray.dt.isocalendar().week`` - instead (:pull:`4534`). By `Mathias Hauser `_, + instead (:pull:`4534`). By `Mathias Hauser `_. `Maximilian Roos `_, and `Spencer Clark `_. - :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` no longer support passing ``keep_attrs`` via its constructor. Pass ``keep_attrs`` via the applied function, i.e. use @@ -226,7 +805,7 @@ Documentation By `Pieter Gijsbers `_. - Fix grammar and typos in the :doc:`contributing` guide (:pull:`4545`). By `Sahid Velji `_. -- Fix grammar and typos in the :doc:`io` guide (:pull:`4553`). +- Fix grammar and typos in the :doc:`user-guide/io` guide (:pull:`4553`). By `Sahid Velji `_. - Update link to NumPy docstring standard in the :doc:`contributing` guide (:pull:`4558`). By `Sahid Velji `_. @@ -318,14 +897,13 @@ New Features By `Aaron Spring `_. - Use a wrapped array's ``_repr_inline_`` method to construct the collapsed ``repr`` of :py:class:`DataArray` and :py:class:`Dataset` objects and - document the new method in :doc:`internals`. (:pull:`4248`). + document the new method in :doc:`internals/index`. (:pull:`4248`). By `Justus Magin `_. - Allow per-variable fill values in most functions. (:pull:`4237`). By `Justus Magin `_. - Expose ``use_cftime`` option in :py:func:`~xarray.open_zarr` (:issue:`2886`, :pull:`3229`) By `Samnan Rahee `_ and `Anderson Banihirwe `_. - Bug fixes ~~~~~~~~~ @@ -785,7 +1363,7 @@ Internal Changes v0.15.0 (30 Jan 2020) --------------------- -This release brings many improvements to xarray's documentation: our examples are now binderized notebooks (`click here `_) +This release brings many improvements to xarray's documentation: our examples are now binderized notebooks (`click here `_) and we have new example notebooks from our SciPy 2019 sprint (many thanks to our contributors!). This release also features many API improvements such as a new @@ -2322,7 +2900,7 @@ non-standard calendars used in climate modeling. Documentation ~~~~~~~~~~~~~ -- New FAQ entry, :ref:`related-projects`. +- New FAQ entry, :ref:`ecosystem`. By `Deepak Cherian `_. - :ref:`assigning_values` now includes examples on how to select and assign values to a :py:class:`~xarray.DataArray` with ``.loc``. @@ -2536,7 +3114,7 @@ Documentation - Added apply_ufunc example to :ref:`/examples/weather-data.ipynb#Toy-weather-data` (:issue:`1844`). By `Liam Brannigan `_. - New entry `Why don’t aggregations return Python scalars?` in the - :doc:`faq` (:issue:`1726`). + :doc:`getting-started-guide/faq` (:issue:`1726`). By `0x0L `_. Enhancements @@ -4367,11 +4945,11 @@ Highlights ~~~~~~~~~~ The headline feature in this release is experimental support for out-of-core -computing (data that doesn't fit into memory) with dask_. This includes a new +computing (data that doesn't fit into memory) with :doc:`user-guide/dask`. This includes a new top-level function ``xray.open_mfdataset`` that makes it easy to open a collection of netCDF (using dask) as a single ``xray.Dataset`` object. For more on dask, read the `blog post introducing xray + dask`_ and the new -documentation section :doc:`dask`. +documentation section :doc:`user-guide/dask`. .. _blog post introducing xray + dask: https://www.anaconda.com/blog/developer-blog/xray-dask-out-core-labeled-arrays-python/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..f1f1a2ac8a6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel", + "setuptools_scm[toml]>=3.4", + "setuptools_scm_git_archive", +] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +fallback_version = "999" diff --git a/setup.cfg b/setup.cfg index a695191bf02..c44d207bf0f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -66,6 +66,8 @@ classifiers = Programming Language :: Python :: 3 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering [options] @@ -74,13 +76,9 @@ zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.htm include_package_data = True python_requires = >=3.7 install_requires = - numpy >= 1.15 - pandas >= 0.25 + numpy >= 1.17 + pandas >= 1.0 setuptools >= 40.4 # For pkg_resources -setup_requires = - setuptools >= 40.4 - setuptools_scm - [options.extras_require] io = @@ -93,6 +91,7 @@ io = cftime rasterio cfgrib + pooch ## Scitools packages & dependencies (e.g: cartopy, cf-units) can be hard to install # scitools-iris @@ -153,6 +152,8 @@ ignore = E501 # line too long - let black worry about that E731 # do not assign a lambda expression, use a def W503 # line break before binary operator +per-file-ignores = + xarray/tests/*.py:F401,F811 exclude= .eggs doc @@ -164,11 +165,18 @@ force_to_top = true default_section = THIRDPARTY known_first_party = xarray +[mypy] +exclude = properties|asv_bench|doc +files = xarray/**/*.py +show_error_codes = True + # Most of the numerical computing stack doesn't have type annotations yet. [mypy-affine.*] ignore_missing_imports = True [mypy-bottleneck.*] ignore_missing_imports = True +[mypy-cartopy.*] +ignore_missing_imports = True [mypy-cdms2.*] ignore_missing_imports = True [mypy-cf_units.*] @@ -183,6 +191,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-distributed.*] ignore_missing_imports = True +[mypy-fsspec.*] +ignore_missing_imports = True [mypy-h5netcdf.*] ignore_missing_imports = True [mypy-h5py.*] @@ -207,6 +217,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-pint.*] ignore_missing_imports = True +[mypy-pooch.*] +ignore_missing_imports = True [mypy-PseudoNetCDF.*] ignore_missing_imports = True [mypy-pydap.*] diff --git a/xarray/__init__.py b/xarray/__init__.py index 3886edc60e6..8321aba4b46 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -18,12 +18,12 @@ from .core.alignment import align, broadcast from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, corr, cov, dot, polyval, where +from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset from .core.extensions import register_dataarray_accessor, register_dataset_accessor -from .core.merge import MergeError, merge +from .core.merge import Context, MergeError, merge from .core.options import set_options from .core.parallel import map_blocks from .core.variable import Coordinate, IndexVariable, Variable, as_variable @@ -74,10 +74,12 @@ "save_mfdataset", "set_options", "show_versions", + "unify_chunks", "where", "zeros_like", # Classes "CFTimeIndex", + "Context", "Coordinate", "DataArray", "Dataset", diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 1500ea5061f..2ebf7a4244b 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,7 +4,7 @@ formats. They should not be used directly, but rather through Dataset objects. """ from .cfgrib_ import CfGribDataStore -from .common import AbstractDataStore +from .common import AbstractDataStore, BackendArray, BackendEntrypoint from .file_manager import CachingFileManager, DummyFileManager, FileManager from .h5netcdf_ import H5NetCDFStore from .memory import InMemoryDataStore @@ -18,6 +18,8 @@ __all__ = [ "AbstractDataStore", + "BackendArray", + "BackendEntrypoint", "FileManager", "CachingFileManager", "CfGribDataStore", diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 81314588784..9b4fa8fce5a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,4 +1,5 @@ import os +import warnings from glob import glob from io import BytesIO from numbers import Number @@ -11,6 +12,7 @@ Iterable, Mapping, MutableMapping, + Optional, Tuple, Union, ) @@ -26,8 +28,9 @@ ) from ..core.dataarray import DataArray from ..core.dataset import Dataset, _get_chunk, _maybe_chunk -from ..core.utils import close_on_error, is_grib_path, is_remote_uri, read_magic_number -from .common import AbstractDataStore, ArrayWriter +from ..core.utils import is_remote_uri +from . import plugins +from .common import AbstractDataStore, ArrayWriter, _normalize_path from .locks import _get_scheduler if TYPE_CHECKING: @@ -70,26 +73,6 @@ def _get_default_engine_remote_uri(): return engine -def _get_default_engine_grib(): - msgs = [] - try: - import Nio # noqa: F401 - - msgs += ["set engine='pynio' to access GRIB files with PyNIO"] - except ImportError: # pragma: no cover - pass - try: - import cfgrib # noqa: F401 - - msgs += ["set engine='cfgrib' to access GRIB files with cfgrib"] - except ImportError: # pragma: no cover - pass - if msgs: - raise ValueError(" or\n".join(msgs)) - else: - raise ValueError("PyNIO or cfgrib is required for accessing GRIB files") - - def _get_default_engine_gz(): try: import scipy # noqa: F401 @@ -118,63 +101,13 @@ def _get_default_engine_netcdf(): return engine -def _get_engine_from_magic_number(filename_or_obj): - magic_number = read_magic_number(filename_or_obj) - - if magic_number.startswith(b"CDF"): - engine = "scipy" - elif magic_number.startswith(b"\211HDF\r\n\032\n"): - engine = "h5netcdf" - else: - raise ValueError( - "cannot guess the engine, " - f"{magic_number} is not the signature of any supported file format " - "did you mean to pass a string for a path instead?" - ) - return engine - - def _get_default_engine(path: str, allow_remote: bool = False): if allow_remote and is_remote_uri(path): - engine = _get_default_engine_remote_uri() - elif is_grib_path(path): - engine = _get_default_engine_grib() + return _get_default_engine_remote_uri() elif path.endswith(".gz"): - engine = _get_default_engine_gz() + return _get_default_engine_gz() else: - engine = _get_default_engine_netcdf() - return engine - - -def _autodetect_engine(filename_or_obj): - if isinstance(filename_or_obj, AbstractDataStore): - engine = "store" - elif isinstance(filename_or_obj, (str, Path)): - engine = _get_default_engine(str(filename_or_obj), allow_remote=True) - else: - engine = _get_engine_from_magic_number(filename_or_obj) - return engine - - -def _get_backend_cls(engine, engines=ENGINES): - """Select open_dataset method based on current engine""" - try: - return engines[engine] - except KeyError: - raise ValueError( - "unrecognized engine for open_dataset: {}\n" - "must be one of: {}".format(engine, list(ENGINES)) - ) - - -def _normalize_path(path): - if isinstance(path, Path): - path = str(path) - - if isinstance(path, str) and not is_remote_uri(path): - path = os.path.abspath(os.path.expanduser(path)) - - return path + return _get_default_engine_netcdf() def _validate_dataset_names(dataset): @@ -199,12 +132,21 @@ def check_name(name): check_name(k) -def _validate_attrs(dataset): +def _validate_attrs(dataset, invalid_netcdf=False): """`attrs` must have a string key and a value which is either: a number, - a string, an ndarray or a list/tuple of numbers/strings. + a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_. + + Notes + ----- + A numpy.bool_ is only allowed when using the h5netcdf engine with + `invalid_netcdf=True`. """ - def check_attr(name, value): + valid_types = (str, Number, np.ndarray, np.number, list, tuple) + if invalid_netcdf: + valid_types += (np.bool_,) + + def check_attr(name, value, valid_types): if isinstance(name, str): if not name: raise ValueError( @@ -218,22 +160,46 @@ def check_attr(name, value): "serialization to netCDF files" ) - if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)): + if not isinstance(value, valid_types): raise TypeError( - f"Invalid value for attr {name!r}: {value!r} must be a number, " - "a string, an ndarray or a list/tuple of " - "numbers/strings for serialization to netCDF " - "files" + f"Invalid value for attr {name!r}: {value!r}. For serialization to " + "netCDF files, its value must be of one of the following types: " + f"{', '.join([vtype.__name__ for vtype in valid_types])}" ) # Check attrs on the dataset itself for k, v in dataset.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) # Check attrs on each variable within the dataset for variable in dataset.variables.values(): for k, v in variable.attrs.items(): - check_attr(k, v) + check_attr(k, v, valid_types) + + +def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders): + for d in list(decoders): + if decode_cf is False and d in open_backend_dataset_parameters: + decoders[d] = False + if decoders[d] is None: + decoders.pop(d) + return decoders + + +def _get_mtime(filename_or_obj): + # if passed an actual file path, augment the token with + # the file modification time + mtime = None + + try: + path = os.fspath(filename_or_obj) + except TypeError: + path = None + + if path and not is_remote_uri(path): + mtime = os.path.getmtime(filename_or_obj) + + return mtime def _protect_dataset_variables_inplace(dataset, cache): @@ -247,7 +213,7 @@ def _protect_dataset_variables_inplace(dataset, cache): def _finalize_store(write, store): - """ Finalize this store by explicitly syncing and closing""" + """Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first store.close() @@ -304,22 +270,86 @@ def load_dataarray(filename_or_obj, **kwargs): return da.load() +def _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + **extra_tokens, +): + from dask.base import tokenize + + mtime = _get_mtime(filename_or_obj) + token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + name_prefix = f"open_dataset-{token}" + + variables = {} + for name, var in backend_ds.variables.items(): + var_chunks = _get_chunk(var, chunks) + variables[name] = _maybe_chunk( + name, + var, + var_chunks, + overwrite_encoded_chunks=overwrite_encoded_chunks, + name_prefix=name_prefix, + token=token, + ) + return backend_ds._replace(variables) + + +def _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + **extra_tokens, +): + if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: + raise ValueError( + f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." + ) + + _protect_dataset_variables_inplace(backend_ds, cache) + if chunks is None: + ds = backend_ds + else: + ds = _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + **extra_tokens, + ) + + ds.set_close(backend_ds._close) + + # Ensure source filename always stored in dataset object (GH issue #2550) + if "source" not in ds.encoding and isinstance(filename_or_obj, str): + ds.encoding["source"] = filename_or_obj + + return ds + + def open_dataset( filename_or_obj, - group=None, - decode_cf=True, - mask_and_scale=None, - decode_times=True, - concat_characters=True, - decode_coords=True, + *args, engine=None, chunks=None, - lock=None, cache=None, + decode_cf=None, + mask_and_scale=None, + decode_times=None, + decode_timedelta=None, + use_cftime=None, + concat_characters=None, + decode_coords=None, drop_variables=None, backend_kwargs=None, - use_cftime=None, - decode_timedelta=None, + **kwargs, ): """Open and decode a dataset from a file or file-like object. @@ -331,50 +361,20 @@ def open_dataset( ends with .gz, in which case the file is gunzipped and opened with scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - group : str, optional - Path to the netCDF4 group in the given file to open (only works for - netCDF4 files). - decode_cf : bool, optional - Whether to decode these variables, assuming they were saved according - to CF conventions. - mask_and_scale : bool, optional - If True, replace array values equal to `_FillValue` with NA and scale - values according to the formula `original_values * scale_factor + - add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are - taken from variable attributes (if they exist). If the `_FillValue` or - `missing_value` attribute contains multiple values a warning will be - issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the - pseudonetcdf backend. - decode_times : bool, optional - If True, decode times encoded in the standard NetCDF datetime format - into datetime objects. Otherwise, leave them encoded as numbers. - concat_characters : bool, optional - If True, concatenate along the last dimension of character arrays to - form string arrays. Dimensions will only be concatenated over (and - removed) if they have no corresponding variable and if they are only - used as the last dimension of character arrays. - decode_coords : bool, optional - If True, decode the 'coordinates' attribute to identify coordinates in - the resulting dataset. engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \ - "pseudonetcdf", "zarr"}, optional + "pseudonetcdf", "zarr"} or subclass of xarray.backends.BackendEntrypoint, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for - "netcdf4". + "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) + can also be used. chunks : int or dict, optional If chunks is provided, it is used to load the new dataset into dask arrays. ``chunks=-1`` loads the dataset with dask using a single - chunk for all arrays. `chunks={}`` loads the dataset with dask using + chunk for all arrays. ``chunks={}`` loads the dataset with dask using engine preferred chunks if exposed by the backend, otherwise with a single chunk for all arrays. ``chunks='auto'`` will use dask ``auto`` chunking taking into account the engine preferred chunks. See dask chunking for more details. - lock : False or lock-like, optional - Resource lock to use when reading data from disk. Only relevant when - using dask or another form of parallelism. By default, appropriate - locks are chosen to safely read and write files with the currently - active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -382,14 +382,28 @@ def open_dataset( argument to use dask, in which case it defaults to False. Does not change the behavior of coordinates corresponding to dimensions, which always load their data from disk into a ``pandas.Index``. - drop_variables: str or iterable, optional - A variable or list of variables to exclude from being parsed from the - dataset. This may be useful to drop variables with problems or - inconsistent values. - backend_kwargs: dict, optional - A dictionary of keyword arguments to pass on to the backend. This - may be useful when backend options would improve performance or - allow user control of dataset processing. + decode_cf : bool, optional + Whether to decode these variables, assuming they were saved according + to CF conventions. + mask_and_scale : bool, optional + If True, replace array values equal to `_FillValue` with NA and scale + values according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. This keyword may not be supported by all the backends. + decode_times : bool, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + This keyword may not be supported by all the backends. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + This keyword may not be supported by all the backends. use_cftime: bool, optional Only relevant if encoded dates come from a standard calendar (e.g. "gregorian", "proleptic_gregorian", "standard", or not @@ -399,12 +413,41 @@ def open_dataset( ``cftime.datetime`` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible - raise an error. - decode_timedelta : bool, optional - If True, decode variables and coordinates with time units in - {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} - into timedelta objects. If False, leave them encoded as numbers. - If None (default), assume the same value of decode_time. + raise an error. This keyword may not be supported by all the backends. + concat_characters : bool, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + This keyword may not be supported by all the backends. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. + drop_variables: str or iterable, optional + A variable or list of variables to exclude from being parsed from the + dataset. This may be useful to drop variables with problems or + inconsistent values. + backend_kwargs: dict + Additional keyword arguments passed on to the engine open function, + equivalent to `**kwargs`. + **kwargs: dict + Additional keyword arguments passed on to the engine open function. + For example: + + - 'group': path to the netCDF4 group in the given file to open given as + a str,supported by "netcdf4", "h5netcdf", "zarr". + - 'lock': resource lock to use when reading data from disk. Only + relevant when using dask or another form of parallelism. By default, + appropriate locks are chosen to safely read and write files with the + currently active dask scheduler. Supported by "netcdf4", "h5netcdf", + "scipy", "pynio", "pseudonetcdf", "cfgrib". + + See engine open function for kwargs accepted by each specific engine. Returns ------- @@ -422,159 +465,71 @@ def open_dataset( -------- open_mfdataset """ - if os.environ.get("XARRAY_BACKEND_API", "v1") == "v2": - kwargs = {k: v for k, v in locals().items() if v is not None} - from . import apiv2 - - return apiv2.open_dataset(**kwargs) - - if mask_and_scale is None: - mask_and_scale = not engine == "pseudonetcdf" - - if not decode_cf: - mask_and_scale = False - decode_times = False - concat_characters = False - decode_coords = False - decode_timedelta = False + if len(args) > 0: + raise TypeError( + "open_dataset() takes only 1 positional argument starting from version 0.18.0, " + "all other options must be passed as keyword arguments" + ) if cache is None: cache = chunks is None - if backend_kwargs is None: - backend_kwargs = {} - - def maybe_decode_store(store, chunks): - ds = conventions.decode_cf( - store, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - concat_characters=concat_characters, - decode_coords=decode_coords, - drop_variables=drop_variables, - use_cftime=use_cftime, - decode_timedelta=decode_timedelta, - ) - - _protect_dataset_variables_inplace(ds, cache) - - if chunks is not None and engine != "zarr": - from dask.base import tokenize - - # if passed an actual file path, augment the token with - # the file modification time - if isinstance(filename_or_obj, str) and not is_remote_uri(filename_or_obj): - mtime = os.path.getmtime(filename_or_obj) - else: - mtime = None - token = tokenize( - filename_or_obj, - mtime, - group, - decode_cf, - mask_and_scale, - decode_times, - concat_characters, - decode_coords, - engine, - chunks, - drop_variables, - use_cftime, - decode_timedelta, - ) - name_prefix = "open_dataset-%s" % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) - - elif engine == "zarr": - # adapted from Dataset.Chunk() and taken from open_zarr - if not (isinstance(chunks, (int, dict)) or chunks is None): - if chunks != "auto": - raise ValueError( - "chunks must be an int, dict, 'auto', or None. " - "Instead found %s. " % chunks - ) - - if chunks == "auto": - try: - import dask.array # noqa - except ImportError: - chunks = None - - # auto chunking needs to be here and not in ZarrStore because - # the variable chunks does not survive decode_cf - # return trivial case - if chunks is None: - return ds - - if isinstance(chunks, int): - chunks = dict.fromkeys(ds.dims, chunks) - - variables = { - k: _maybe_chunk( - k, - v, - _get_chunk(v, chunks), - overwrite_encoded_chunks=overwrite_encoded_chunks, - ) - for k, v in ds.variables.items() - } - ds2 = ds._replace(variables) - - else: - ds2 = ds - ds2.set_close(ds._close) - return ds2 + if backend_kwargs is not None: + kwargs.update(backend_kwargs) - filename_or_obj = _normalize_path(filename_or_obj) + if engine is None: + engine = plugins.guess_engine(filename_or_obj) - if isinstance(filename_or_obj, AbstractDataStore): - store = filename_or_obj - else: - if engine is None: - engine = _autodetect_engine(filename_or_obj) - - extra_kwargs = {} - if group is not None: - extra_kwargs["group"] = group - if lock is not None: - extra_kwargs["lock"] = lock - - if engine == "zarr": - backend_kwargs = backend_kwargs.copy() - overwrite_encoded_chunks = backend_kwargs.pop( - "overwrite_encoded_chunks", None - ) + backend = plugins.get_backend(engine) - opener = _get_backend_cls(engine) - store = opener(filename_or_obj, **extra_kwargs, **backend_kwargs) - - with close_on_error(store): - ds = maybe_decode_store(store, chunks) - - # Ensure source filename always stored in dataset object (GH issue #2550) - if "source" not in ds.encoding: - if isinstance(filename_or_obj, str): - ds.encoding["source"] = filename_or_obj + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=backend.open_dataset_parameters, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + backend_ds = backend.open_dataset( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + ds = _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) return ds def open_dataarray( filename_or_obj, - group=None, - decode_cf=True, - mask_and_scale=None, - decode_times=True, - concat_characters=True, - decode_coords=True, + *args, engine=None, chunks=None, - lock=None, cache=None, + decode_cf=None, + mask_and_scale=None, + decode_times=None, + decode_timedelta=None, + use_cftime=None, + concat_characters=None, + decode_coords=None, drop_variables=None, backend_kwargs=None, - use_cftime=None, - decode_timedelta=None, + **kwargs, ): """Open an DataArray from a file or file-like object containing a single data variable. @@ -585,14 +540,31 @@ def open_dataarray( Parameters ---------- filename_or_obj : str, Path, file-like or DataStore - Strings and Paths are interpreted as a path to a netCDF file or an - OpenDAP URL and opened with python-netCDF4, unless the filename ends - with .gz, in which case the file is gunzipped and opened with + Strings and Path objects are interpreted as a path to a netCDF file + or an OpenDAP URL and opened with python-netCDF4, unless the filename + ends with .gz, in which case the file is gunzipped and opened with scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - group : str, optional - Path to the netCDF4 group in the given file to open (only works for - netCDF4 files). + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \ + "pseudonetcdf", "zarr"}, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". + chunks : int or dict, optional + If chunks is provided, it is used to load the new dataset into dask + arrays. ``chunks=-1`` loads the dataset with dask using a single + chunk for all arrays. `chunks={}`` loads the dataset with dask using + engine preferred chunks if exposed by the backend, otherwise with + a single chunk for all arrays. + ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. See dask chunking for more details. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. decode_cf : bool, optional Whether to decode these variables, assuming they were saved according to CF conventions. @@ -604,46 +576,17 @@ def open_dataarray( `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will be replaced by NA. mask_and_scale defaults to True except for the - pseudonetcdf backend. + pseudonetcdf backend. This keyword may not be supported by all the backends. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - concat_characters : bool, optional - If True, concatenate along the last dimension of character arrays to - form string arrays. Dimensions will only be concatenated over (and - removed) if they have no corresponding variable and if they are only - used as the last dimension of character arrays. - decode_coords : bool, optional - If True, decode the 'coordinates' attribute to identify coordinates in - the resulting dataset. - engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib"}, \ - optional - Engine to use when reading files. If not provided, the default engine - is chosen based on available dependencies, with a preference for - "netcdf4". - chunks : int or dict, optional - If chunks is provided, it used to load the new dataset into dask - arrays. - lock : False or lock-like, optional - Resource lock to use when reading data from disk. Only relevant when - using dask or another form of parallelism. By default, appropriate - locks are chosen to safely read and write files with the currently - active dask scheduler. - cache : bool, optional - If True, cache data loaded from the underlying datastore in memory as - NumPy arrays when accessed to avoid reading from the underlying data- - store multiple times. Defaults to True unless you specify the `chunks` - argument to use dask, in which case it defaults to False. Does not - change the behavior of coordinates corresponding to dimensions, which - always load their data from disk into a ``pandas.Index``. - drop_variables: str or iterable, optional - A variable or list of variables to exclude from being parsed from the - dataset. This may be useful to drop variables with problems or - inconsistent values. - backend_kwargs: dict, optional - A dictionary of keyword arguments to pass on to the backend. This - may be useful when backend options would improve performance or - allow user control of dataset processing. + This keyword may not be supported by all the backends. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + This keyword may not be supported by all the backends. use_cftime: bool, optional Only relevant if encoded dates come from a standard calendar (e.g. "gregorian", "proleptic_gregorian", "standard", or not @@ -653,12 +596,41 @@ def open_dataarray( ``cftime.datetime`` objects, regardless of whether or not they can be represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible - raise an error. - decode_timedelta : bool, optional - If True, decode variables and coordinates with time units in - {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} - into timedelta objects. If False, leave them encoded as numbers. - If None (default), assume the same value of decode_time. + raise an error. This keyword may not be supported by all the backends. + concat_characters : bool, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + This keyword may not be supported by all the backends. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. + drop_variables: str or iterable, optional + A variable or list of variables to exclude from being parsed from the + dataset. This may be useful to drop variables with problems or + inconsistent values. + backend_kwargs: dict + Additional keyword arguments passed on to the engine open function, + equivalent to `**kwargs`. + **kwargs: dict + Additional keyword arguments passed on to the engine open function. + For example: + + - 'group': path to the netCDF4 group in the given file to open given as + a str,supported by "netcdf4", "h5netcdf", "zarr". + - 'lock': resource lock to use when reading data from disk. Only + relevant when using dask or another form of parallelism. By default, + appropriate locks are chosen to safely read and write files with the + currently active dask scheduler. Supported by "netcdf4", "h5netcdf", + "scipy", "pynio", "pseudonetcdf", "cfgrib". + + See engine open function for kwargs accepted by each specific engine. Notes ----- @@ -673,10 +645,14 @@ def open_dataarray( -------- open_dataset """ + if len(args) > 0: + raise TypeError( + "open_dataarray() takes only 1 positional argument starting from version 0.18.0, " + "all other options must be passed as keyword arguments" + ) dataset = open_dataset( filename_or_obj, - group=group, decode_cf=decode_cf, mask_and_scale=mask_and_scale, decode_times=decode_times, @@ -684,12 +660,12 @@ def open_dataarray( decode_coords=decode_coords, engine=engine, chunks=chunks, - lock=lock, cache=cache, drop_variables=drop_variables, backend_kwargs=backend_kwargs, use_cftime=use_cftime, decode_timedelta=decode_timedelta, + **kwargs, ) if len(dataset.data_vars) != 1: @@ -722,13 +698,13 @@ def open_mfdataset( compat="no_conflicts", preprocess=None, engine=None, - lock=None, data_vars="all", coords="different", combine="by_coords", parallel=False, join="outer", attrs_file=None, + combine_attrs="override", **kwargs, ): """Open multiple files as a single dataset. @@ -758,7 +734,7 @@ def open_mfdataset( see the full documentation for more details [2]_. concat_dim : str, or list of str, DataArray, Index or None, optional Dimensions to concatenate files along. You only need to provide this argument - if ``combine='by_coords'``, and if any of the dimensions along which you want to + if ``combine='nested'``, and if any of the dimensions along which you want to concatenate is not a dimension in the original datasets, e.g., if you want to stack a collection of 2D arrays along a third dimension. Set ``concat_dim=[..., None, ...]`` explicitly to disable concatenation along a @@ -792,11 +768,6 @@ def open_mfdataset( Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4". - lock : False or lock-like, optional - Resource lock to use when reading data from disk. Only relevant when - using dask or another form of parallelism. By default, appropriate - locks are chosen to safely read and write files with the currently - active dask scheduler. data_vars : {"minimal", "different", "all"} or list of str, optional These data variables will be concatenated together: * "minimal": Only data variables in which the dimension already @@ -869,30 +840,63 @@ def open_mfdataset( .. [2] http://xarray.pydata.org/en/stable/dask.html#chunking-and-performance """ if isinstance(paths, str): - if is_remote_uri(paths): + if is_remote_uri(paths) and engine == "zarr": + try: + from fsspec.core import get_fs_token_paths + except ImportError as e: + raise ImportError( + "The use of remote URLs for opening zarr requires the package fsspec" + ) from e + + fs, _, _ = get_fs_token_paths( + paths, + mode="rb", + storage_options=kwargs.get("backend_kwargs", {}).get( + "storage_options", {} + ), + expand=False, + ) + paths = fs.glob(fs._strip_protocol(paths)) # finds directories + paths = [fs.get_mapper(path) for path in paths] + elif is_remote_uri(paths): raise ValueError( "cannot do wild-card matching for paths that are remote URLs: " "{!r}. Instead, supply paths as an explicit list of strings.".format( paths ) ) - paths = sorted(glob(_normalize_path(paths))) + else: + paths = sorted(glob(_normalize_path(paths))) else: paths = [str(p) if isinstance(p, Path) else p for p in paths] if not paths: raise OSError("no files to open") - # If combine='by_coords' then this is unnecessary, but quick. - # If combine='nested' then this creates a flat list which is easier to - # iterate over, while saving the originally-supplied structure as "ids" if combine == "nested": if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: concat_dim = [concat_dim] - combined_ids_paths = _infer_concat_order_from_positions(paths) - ids, paths = (list(combined_ids_paths.keys()), list(combined_ids_paths.values())) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, **kwargs) + # This creates a flat list which is easier to iterate over, whilst + # encoding the originally-supplied structure as "ids". + # The "ids" are not used at all if combine='by_coords`. + combined_ids_paths = _infer_concat_order_from_positions(paths) + ids, paths = ( + list(combined_ids_paths.keys()), + list(combined_ids_paths.values()), + ) + + # TODO raise an error instead of a warning after v0.19 + elif combine == "by_coords" and concat_dim is not None: + warnings.warn( + "When combine='by_coords', passing a value for `concat_dim` has no " + "effect. This combination will raise an error in future. To manually " + "combine along a specific dimension you should instead specify " + "combine='nested' along with a value for `concat_dim`.", + DeprecationWarning, + ) + + open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs) if parallel: import dask @@ -929,7 +933,7 @@ def open_mfdataset( coords=coords, ids=ids, join=join, - combine_attrs="drop", + combine_attrs=combine_attrs, ) elif combine == "by_coords": # Redo ordering from coordinates, ignoring how they were ordered @@ -940,7 +944,7 @@ def open_mfdataset( data_vars=data_vars, coords=coords, join=join, - combine_attrs="drop", + combine_attrs=combine_attrs, ) else: raise ValueError( @@ -963,8 +967,6 @@ def multi_file_closer(): if isinstance(attrs_file, Path): attrs_file = str(attrs_file) combined.attrs = datasets[paths.index(attrs_file)].attrs - else: - combined.attrs = datasets[0].attrs return combined @@ -1008,8 +1010,8 @@ def to_netcdf( elif engine != "scipy": raise ValueError( "invalid engine for creating bytes with " - "to_netcdf: %r. Only the default engine " - "or engine='scipy' is supported" % engine + f"to_netcdf: {engine!r}. Only the default engine " + "or engine='scipy' is supported" ) if not compute: raise NotImplementedError( @@ -1025,12 +1027,12 @@ def to_netcdf( # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) - _validate_attrs(dataset) + _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf") try: store_open = WRITEABLE_STORES[engine] except KeyError: - raise ValueError("unrecognized engine for to_netcdf: %r" % engine) + raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") if format is not None: format = format.upper() @@ -1042,9 +1044,8 @@ def to_netcdf( autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"] if autoclose and engine == "scipy": raise NotImplementedError( - "Writing netCDF files with the %s backend " - "is not currently supported with dask's %s " - "scheduler" % (engine, scheduler) + f"Writing netCDF files with the {engine} backend " + f"is not currently supported with dask's {scheduler} scheduler" ) target = path_or_file if path_or_file is not None else BytesIO() @@ -1054,7 +1055,7 @@ def to_netcdf( kwargs["invalid_netcdf"] = invalid_netcdf else: raise ValueError( - "unrecognized option 'invalid_netcdf' for engine %s" % engine + f"unrecognized option 'invalid_netcdf' for engine {engine}" ) store = store_open(target, mode, format, group, **kwargs) @@ -1196,7 +1197,7 @@ def save_mfdataset( Data variables: a (time) float64 0.0 0.02128 0.04255 0.06383 ... 0.9574 0.9787 1.0 >>> years, datasets = zip(*ds.groupby("time.year")) - >>> paths = ["%s.nc" % y for y in years] + >>> paths = [f"{y}.nc" for y in years] >>> xr.save_mfdataset(datasets, paths) """ if mode == "w" and len(set(paths)) < len(paths): @@ -1208,7 +1209,7 @@ def save_mfdataset( if not isinstance(obj, Dataset): raise TypeError( "save_mfdataset only supports writing Dataset " - "objects, received type %s" % type(obj) + f"objects, received type {type(obj)}" ) if groups is None: @@ -1245,6 +1246,42 @@ def save_mfdataset( ) +def _validate_region(ds, region): + if not isinstance(region, dict): + raise TypeError(f"``region`` must be a dict, got {type(region)}") + + for k, v in region.items(): + if k not in ds.dims: + raise ValueError( + f"all keys in ``region`` are not in Dataset dimensions, got " + f"{list(region)} and {list(ds.dims)}" + ) + if not isinstance(v, slice): + raise TypeError( + "all values in ``region`` must be slice objects, got " + f"region={region}" + ) + if v.step not in {1, None}: + raise ValueError( + "step on all slices in ``region`` must be 1 or None, got " + f"region={region}" + ) + + non_matching_vars = [ + k for k, v in ds.variables.items() if not set(region).intersection(v.dims) + ] + if non_matching_vars: + raise ValueError( + f"when setting `region` explicitly in to_zarr(), all " + f"variables in the dataset to write must have at least " + f"one dimension in common with the region's dimensions " + f"{list(region.keys())}, but that is not " + f"the case for some variables here. To drop these variables " + f"from this dataset before exporting to zarr, write: " + f".drop({non_matching_vars!r})" + ) + + def _validate_datatypes_for_zarr_append(dataset): """DataArray.name and Dataset keys must be a string or None""" @@ -1269,98 +1306,6 @@ def check_dtype(var): check_dtype(k) -def _validate_append_dim_and_encoding( - ds_to_append, store, append_dim, region, encoding, **open_kwargs -): - try: - ds = backends.zarr.open_zarr(store, **open_kwargs) - except ValueError: # store empty - return - - if append_dim: - if append_dim not in ds.dims: - raise ValueError( - f"append_dim={append_dim!r} does not match any existing " - f"dataset dimensions {ds.dims}" - ) - if region is not None and append_dim in region: - raise ValueError( - f"cannot list the same dimension in both ``append_dim`` and " - f"``region`` with to_zarr(), got {append_dim} in both" - ) - - if region is not None: - if not isinstance(region, dict): - raise TypeError(f"``region`` must be a dict, got {type(region)}") - for k, v in region.items(): - if k not in ds_to_append.dims: - raise ValueError( - f"all keys in ``region`` are not in Dataset dimensions, got " - f"{list(region)} and {list(ds_to_append.dims)}" - ) - if not isinstance(v, slice): - raise TypeError( - "all values in ``region`` must be slice objects, got " - f"region={region}" - ) - if v.step not in {1, None}: - raise ValueError( - "step on all slices in ``region`` must be 1 or None, got " - f"region={region}" - ) - - non_matching_vars = [ - k - for k, v in ds_to_append.variables.items() - if not set(region).intersection(v.dims) - ] - if non_matching_vars: - raise ValueError( - f"when setting `region` explicitly in to_zarr(), all " - f"variables in the dataset to write must have at least " - f"one dimension in common with the region's dimensions " - f"{list(region.keys())}, but that is not " - f"the case for some variables here. To drop these variables " - f"from this dataset before exporting to zarr, write: " - f".drop({non_matching_vars!r})" - ) - - for var_name, new_var in ds_to_append.variables.items(): - if var_name in ds.variables: - existing_var = ds.variables[var_name] - if new_var.dims != existing_var.dims: - raise ValueError( - f"variable {var_name!r} already exists with different " - f"dimension names {existing_var.dims} != " - f"{new_var.dims}, but changing variable " - f"dimensions is not supported by to_zarr()." - ) - - existing_sizes = {} - for dim, size in existing_var.sizes.items(): - if region is not None and dim in region: - start, stop, stride = region[dim].indices(size) - assert stride == 1 # region was already validated above - size = stop - start - if dim != append_dim: - existing_sizes[dim] = size - - new_sizes = { - dim: size for dim, size in new_var.sizes.items() if dim != append_dim - } - if existing_sizes != new_sizes: - raise ValueError( - f"variable {var_name!r} already exists with different " - f"dimension sizes: {existing_sizes} != {new_sizes}. " - f"to_zarr() only supports changing dimension sizes when " - f"explicitly appending, but append_dim={append_dim!r}." - ) - if var_name in encoding.keys(): - raise ValueError( - f"variable {var_name!r} already exists, but encoding was provided" - ) - - def to_zarr( dataset: Dataset, store: Union[MutableMapping, str, Path] = None, @@ -1370,9 +1315,10 @@ def to_zarr( group: str = None, encoding: Mapping = None, compute: bool = True, - consolidated: bool = False, + consolidated: Optional[bool] = None, append_dim: Hashable = None, region: Mapping[str, slice] = None, + safe_chunks: bool = True, ): """This function creates an appropriate datastore for writing a dataset to a zarr ztore @@ -1388,57 +1334,81 @@ def to_zarr( encoding = {} if mode is None: - if append_dim is not None or region is not None: + if append_dim is not None: mode = "a" + elif region is not None: + mode = "r+" else: mode = "w-" if mode != "a" and append_dim is not None: raise ValueError("cannot set append_dim unless mode='a' or mode=None") - if mode != "a" and region is not None: - raise ValueError("cannot set region unless mode='a' or mode=None") + if mode not in ["a", "r+"] and region is not None: + raise ValueError("cannot set region unless mode='a', mode='r+' or mode=None") - if mode not in ["w", "w-", "a"]: - # TODO: figure out how to handle 'r+' + if mode not in ["w", "w-", "a", "r+"]: raise ValueError( "The only supported options for mode are 'w', " - f"'w-' and 'a', but mode={mode!r}" - ) - - if consolidated and region is not None: - raise ValueError( - "cannot use consolidated=True when the region argument is set. " - "Instead, set consolidated=True when writing to zarr with " - "compute=False before writing data." + f"'w-', 'a' and 'r+', but mode={mode!r}" ) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) _validate_attrs(dataset) - if mode == "a": - _validate_datatypes_for_zarr_append(dataset) - _validate_append_dim_and_encoding( - dataset, - store, - append_dim, - group=group, - consolidated=consolidated, - region=region, - encoding=encoding, - ) + if region is not None: + _validate_region(dataset, region) + if append_dim is not None and append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {append_dim} in both" + ) + if mode == "r+": + already_consolidated = consolidated + consolidate_on_close = False + else: + already_consolidated = False + consolidate_on_close = consolidated or consolidated is None zstore = backends.ZarrStore.open_group( store=store, mode=mode, synchronizer=synchronizer, group=group, - consolidate_on_close=consolidated, + consolidated=already_consolidated, + consolidate_on_close=consolidate_on_close, chunk_store=chunk_store, append_dim=append_dim, write_region=region, + safe_chunks=safe_chunks, + stacklevel=4, # for Dataset.to_zarr() ) + + if mode in ["a", "r+"]: + _validate_datatypes_for_zarr_append(dataset) + if append_dim is not None: + existing_dims = zstore.get_dimensions() + if append_dim not in existing_dims: + raise ValueError( + f"append_dim={append_dim!r} does not match any existing " + f"dataset dimensions {existing_dims}" + ) + existing_var_names = set(zstore.zarr_group.array_keys()) + for var_name in existing_var_names: + if var_name in encoding.keys(): + raise ValueError( + f"variable {var_name!r} already exists, but encoding was provided" + ) + if mode == "r+": + new_names = [k for k in dataset.variables if k not in existing_var_names] + if new_names: + raise ValueError( + f"dataset contains non-pre-existing variables {new_names}, " + "which is not allowed in ``xarray.Dataset.to_zarr()`` with " + "mode='r+'. To allow writing new variables, set mode='a'." + ) + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims dump_to_store(dataset, zstore, writer, encoding=encoding) diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py deleted file mode 100644 index d31fc9ea773..00000000000 --- a/xarray/backends/apiv2.py +++ /dev/null @@ -1,282 +0,0 @@ -import os - -from ..core import indexing -from ..core.dataset import _get_chunk, _maybe_chunk -from ..core.utils import is_remote_uri -from . import plugins - - -def _protect_dataset_variables_inplace(dataset, cache): - for name, variable in dataset.variables.items(): - if name not in variable.dims: - # no need to protect IndexVariable objects - data = indexing.CopyOnWriteArray(variable._data) - if cache: - data = indexing.MemoryCachedArray(data) - variable.data = data - - -def _get_mtime(filename_or_obj): - # if passed an actual file path, augment the token with - # the file modification time - mtime = None - - try: - path = os.fspath(filename_or_obj) - except TypeError: - path = None - - if path and not is_remote_uri(path): - mtime = os.path.getmtime(filename_or_obj) - - return mtime - - -def _chunk_ds( - backend_ds, - filename_or_obj, - engine, - chunks, - overwrite_encoded_chunks, - **extra_tokens, -): - from dask.base import tokenize - - mtime = _get_mtime(filename_or_obj) - token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) - name_prefix = "open_dataset-%s" % token - - variables = {} - for name, var in backend_ds.variables.items(): - var_chunks = _get_chunk(var, chunks) - variables[name] = _maybe_chunk( - name, - var, - var_chunks, - overwrite_encoded_chunks=overwrite_encoded_chunks, - name_prefix=name_prefix, - token=token, - ) - ds = backend_ds._replace(variables) - return ds - - -def _dataset_from_backend_dataset( - backend_ds, - filename_or_obj, - engine, - chunks, - cache, - overwrite_encoded_chunks, - **extra_tokens, -): - if not (isinstance(chunks, (int, dict)) or chunks is None): - if chunks != "auto": - raise ValueError( - "chunks must be an int, dict, 'auto', or None. " - "Instead found %s. " % chunks - ) - - _protect_dataset_variables_inplace(backend_ds, cache) - if chunks is None: - ds = backend_ds - else: - ds = _chunk_ds( - backend_ds, - filename_or_obj, - engine, - chunks, - overwrite_encoded_chunks, - **extra_tokens, - ) - - ds.set_close(backend_ds._close) - - # Ensure source filename always stored in dataset object (GH issue #2550) - if "source" not in ds.encoding: - if isinstance(filename_or_obj, str): - ds.encoding["source"] = filename_or_obj - - return ds - - -def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders): - for d in list(decoders): - if decode_cf is False and d in open_backend_dataset_parameters: - decoders[d] = False - if decoders[d] is None: - decoders.pop(d) - return decoders - - -def open_dataset( - filename_or_obj, - *, - engine=None, - chunks=None, - cache=None, - decode_cf=None, - mask_and_scale=None, - decode_times=None, - decode_timedelta=None, - use_cftime=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - backend_kwargs=None, - **kwargs, -): - """Open and decode a dataset from a file or file-like object. - - Parameters - ---------- - filename_or_obj : str, Path, file-like or DataStore - Strings and Path objects are interpreted as a path to a netCDF file - or an OpenDAP URL and opened with python-netCDF4, unless the filename - ends with .gz, in which case the file is unzipped and opened with - scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like - objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). - engine : str, optional - Engine to use when reading files. If not provided, the default engine - is chosen based on available dependencies, with a preference for - "netcdf4". Options are: {"netcdf4", "scipy", "pydap", "h5netcdf",\ - "pynio", "cfgrib", "pseudonetcdf", "zarr"}. - chunks : int or dict, optional - If chunks is provided, it is used to load the new dataset into dask - arrays. ``chunks=-1`` loads the dataset with dask using a single - chunk for all arrays. `chunks={}`` loads the dataset with dask using - engine preferred chunks if exposed by the backend, otherwise with - a single chunk for all arrays. - ``chunks='auto'`` will use dask ``auto`` chunking taking into account the - engine preferred chunks. See dask chunking for more details. - cache : bool, optional - If True, cache data is loaded from the underlying datastore in memory as - NumPy arrays when accessed to avoid reading from the underlying data- - store multiple times. Defaults to True unless you specify the `chunks` - argument to use dask, in which case it defaults to False. Does not - change the behavior of coordinates corresponding to dimensions, which - always load their data from disk into a ``pandas.Index``. - decode_cf : bool, optional - Setting ``decode_cf=False`` will disable ``mask_and_scale``, - ``decode_times``, ``decode_timedelta``, ``concat_characters``, - ``decode_coords``. - mask_and_scale : bool, optional - If True, array values equal to `_FillValue` are replaced with NA and other - values are scaled according to the formula `original_values * scale_factor + - add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are - taken from variable attributes (if they exist). If the `_FillValue` or - `missing_value` attribute contains multiple values, a warning will be - issued and all array values matching one of the multiple values will - be replaced by NA. mask_and_scale defaults to True except for the - pseudonetcdf backend. This keyword may not be supported by all the backends. - decode_times : bool, optional - If True, decode times encoded in the standard NetCDF datetime format - into datetime objects. Otherwise, leave them encoded as numbers. - This keyword may not be supported by all the backends. - decode_timedelta : bool, optional - If True, decode variables and coordinates with time units in - {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} - into timedelta objects. If False, they remain encoded as numbers. - If None (default), assume the same value of decode_time. - This keyword may not be supported by all the backends. - use_cftime: bool, optional - Only relevant if encoded dates come from a standard calendar - (e.g. "gregorian", "proleptic_gregorian", "standard", or not - specified). If None (default), attempt to decode times to - ``np.datetime64[ns]`` objects; if this is not possible, decode times to - ``cftime.datetime`` objects. If True, always decode times to - ``cftime.datetime`` objects, regardless of whether or not they can be - represented using ``np.datetime64[ns]`` objects. If False, always - decode times to ``np.datetime64[ns]`` objects; if this is not possible - raise an error. This keyword may not be supported by all the backends. - concat_characters : bool, optional - If True, concatenate along the last dimension of character arrays to - form string arrays. Dimensions will only be concatenated over (and - removed) if they have no corresponding variable and if they are only - used as the last dimension of character arrays. - This keyword may not be supported by all the backends. - decode_coords : bool, optional - If True, decode the 'coordinates' attribute to identify coordinates in - the resulting dataset. This keyword may not be supported by all the - backends. - drop_variables: str or iterable, optional - A variable or list of variables to exclude from the dataset parsing. - This may be useful to drop variables with problems or - inconsistent values. - backend_kwargs: - Additional keyword arguments passed on to the engine open function. - **kwargs: dict - Additional keyword arguments passed on to the engine open function. - For example: - - - 'group': path to the netCDF4 group in the given file to open given as - a str,supported by "netcdf4", "h5netcdf", "zarr". - - - 'lock': resource lock to use when reading data from disk. Only - relevant when using dask or another form of parallelism. By default, - appropriate locks are chosen to safely read and write files with the - currently active dask scheduler. Supported by "netcdf4", "h5netcdf", - "pynio", "pseudonetcdf", "cfgrib". - - See engine open function for kwargs accepted by each specific engine. - - - Returns - ------- - dataset : Dataset - The newly created dataset. - - Notes - ----- - ``open_dataset`` opens the file with read-only access. When you modify - values of a Dataset, even one linked to files on disk, only the in-memory - copy you are manipulating in xarray is modified: the original file on disk - is never touched. - - See Also - -------- - open_mfdataset - """ - - if cache is None: - cache = chunks is None - - if backend_kwargs is not None: - kwargs.update(backend_kwargs) - - if engine is None: - engine = plugins.guess_engine(filename_or_obj) - - backend = plugins.get_backend(engine) - - decoders = _resolve_decoders_kwargs( - decode_cf, - open_backend_dataset_parameters=backend.open_dataset_parameters, - mask_and_scale=mask_and_scale, - decode_times=decode_times, - decode_timedelta=decode_timedelta, - concat_characters=concat_characters, - use_cftime=use_cftime, - decode_coords=decode_coords, - ) - - overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) - backend_ds = backend.open_dataset( - filename_or_obj, - drop_variables=drop_variables, - **decoders, - **kwargs, - ) - ds = _dataset_from_backend_dataset( - backend_ds, - filename_or_obj, - engine, - chunks, - cache, - overwrite_encoded_chunks, - drop_variables=drop_variables, - **decoders, - **kwargs, - ) - - return ds diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 65c5bc2a02b..e7aeaaba83a 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -1,4 +1,5 @@ import os +import warnings import numpy as np @@ -10,6 +11,7 @@ AbstractDataStore, BackendArray, BackendEntrypoint, + _normalize_path, ) from .locks import SerializableLock, ensure_lock from .store import StoreBackendEntrypoint @@ -20,7 +22,13 @@ has_cfgrib = True except ModuleNotFoundError: has_cfgrib = False - +# cfgrib throws a RuntimeError if eccodes is not installed +except (ImportError, RuntimeError): + warnings.warn( + "Failed to load cfgrib - most likely there is a problem accessing the ecCodes library. " + "Try `import cfgrib` to get the full error message" + ) + has_cfgrib = False # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe # in most circumstances. See: @@ -62,7 +70,7 @@ def open_store_variable(self, name, var): data = var.data else: wrapped_array = CfGribArrayWrapper(self, var.data) - data = indexing.LazilyOuterIndexedArray(wrapped_array) + data = indexing.LazilyIndexedArray(wrapped_array) encoding = self.ds.encoding.copy() encoding["original_shape"] = var.data.shape @@ -82,14 +90,15 @@ def get_dimensions(self): def get_encoding(self): dims = self.get_dimensions() - encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}} - return encoding + return {"unlimited_dims": {k for k, v in dims.items() if v is None}} class CfgribfBackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): + available = has_cfgrib + + def guess_can_open(self, filename_or_obj): try: - _, ext = os.path.splitext(store_spec) + _, ext = os.path.splitext(filename_or_obj) except TypeError: return False return ext in {".grib", ".grib2", ".grb", ".grb2"} @@ -99,9 +108,9 @@ def open_dataset( filename_or_obj, *, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, @@ -114,6 +123,7 @@ def open_dataset( time_dims=("time", "step"), ): + filename_or_obj = _normalize_path(filename_or_obj) store = CfGribDataStore( filename_or_obj, indexpath=indexpath, @@ -139,5 +149,4 @@ def open_dataset( return ds -if has_cfgrib: - BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint +BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint diff --git a/xarray/backends/common.py b/xarray/backends/common.py index e2905d0866b..64a245ddead 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,16 @@ import logging +import os.path import time import traceback -from typing import Dict, Tuple, Type, Union +from pathlib import Path +from typing import Any, Dict, Tuple, Type, Union import numpy as np from ..conventions import cf_encoder from ..core import indexing from ..core.pycompat import is_duck_dask_array -from ..core.utils import FrozenDict, NdimSizeLenMixin +from ..core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -17,6 +19,16 @@ NONE_VAR_NAME = "__values__" +def _normalize_path(path): + if isinstance(path, Path): + path = str(path) + + if isinstance(path, str) and not is_remote_uri(path): + path = os.path.abspath(os.path.expanduser(path)) + + return path + + def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME @@ -57,9 +69,8 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 base_delay = initial_delay * 2 ** n next_delay = base_delay + np.random.randint(base_delay) msg = ( - "getitem failed, waiting %s ms before trying again " - "(%s tries remaining). Full traceback: %s" - % (next_delay, max_retries - n, traceback.format_exc()) + f"getitem failed, waiting {next_delay} ms before trying again " + f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}" ) logger.debug(msg) time.sleep(1e-3 * next_delay) @@ -324,7 +335,7 @@ def set_dimensions(self, variables, unlimited_dims=None): if dim in existing_dims and length != existing_dims[dim]: raise ValueError( "Unable to update size for existing dimension" - "%r (%d != %d)" % (dim, length, existing_dims[dim]) + f"{dim!r} ({length} != {existing_dims[dim]})" ) elif dim not in existing_dims: is_unlimited = dim in unlimited_dims @@ -344,12 +355,41 @@ def encode(self, variables, attributes): class BackendEntrypoint: + """ + ``BackendEntrypoint`` is a class container and it is the main interface + for the backend plugins, see :ref:`RST backend_entrypoint`. + It shall implement: + + - ``open_dataset`` method: it shall implement reading from file, variables + decoding and it returns an instance of :py:class:`~xarray.Dataset`. + It shall take in input at least ``filename_or_obj`` argument and + ``drop_variables`` keyword argument. + For more details see :ref:`RST open_dataset`. + - ``guess_can_open`` method: it shall return ``True`` if the backend is able to open + ``filename_or_obj``, ``False`` otherwise. The implementation of this + method is not mandatory. + """ + open_dataset_parameters: Union[Tuple, None] = None + """list of ``open_dataset`` method parameters""" + + def open_dataset( + self, + filename_or_obj: str, + drop_variables: Tuple[str] = None, + **kwargs: Any, + ): + """ + Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. + """ - def open_dataset(self): raise NotImplementedError - def guess_can_open(self, store_spec): + def guess_can_open(self, filename_or_obj): + """ + Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. + """ + return False diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index aa892c4f89c..3a49928ec65 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -6,12 +6,18 @@ import numpy as np from ..core import indexing -from ..core.utils import FrozenDict, is_remote_uri, read_magic_number +from ..core.utils import ( + FrozenDict, + is_remote_uri, + read_magic_number_from_file, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, BackendEntrypoint, WritableCFDataStore, + _normalize_path, find_root_and_group, ) from .file_manager import CachingFileManager, DummyFileManager @@ -36,8 +42,7 @@ class H5NetCDFArrayWrapper(BaseNetCDF4Array): def get_array(self, needs_lock=True): ds = self.datastore._acquire(needs_lock) - variable = ds.variables[self.variable_name] - return variable + return ds.variables[self.variable_name] def __getitem__(self, key): return indexing.explicit_indexing_adapter( @@ -101,7 +106,7 @@ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=Fal if group is None: root, group = find_root_and_group(manager) else: - if not type(manager) is h5netcdf.File: + if type(manager) is not h5netcdf.File: raise ValueError( "must supply a h5netcdf.File if the group " "argument is provided" @@ -131,6 +136,7 @@ def open( autoclose=False, invalid_netcdf=None, phony_dims=None, + decode_vlen_strings=True, ): if isinstance(filename, bytes): @@ -139,10 +145,10 @@ def open( "try passing a path or file-like object" ) elif isinstance(filename, io.IOBase): - magic_number = read_magic_number(filename) + magic_number = read_magic_number_from_file(filename) if not magic_number.startswith(b"\211HDF\r\n\032\n"): raise ValueError( - f"{magic_number} is not the signature of a valid netCDF file" + f"{magic_number} is not the signature of a valid netCDF4 file" ) if format not in [None, "NETCDF4"]: @@ -157,6 +163,10 @@ def open( "h5netcdf backend keyword argument 'phony_dims' needs " "h5netcdf >= 0.8.0." ) + if LooseVersion(h5netcdf.__version__) >= LooseVersion( + "0.10.0" + ) and LooseVersion(h5netcdf.core.h5py.__version__) >= LooseVersion("3.0.0"): + kwargs["decode_vlen_strings"] = decode_vlen_strings if lock is None: if mode == "r": @@ -182,7 +192,7 @@ def open_store_variable(self, name, var): import h5py dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray(H5NetCDFArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(H5NetCDFArrayWrapper(name, self)) attrs = _read_attributes(var) # netCDF4 specific encoding @@ -227,11 +237,9 @@ def get_dimensions(self): return self.ds.dimensions def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v is None + return { + "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None} } - return encoding def set_dimension(self, name, length, is_unlimited=False): if is_unlimited: @@ -260,9 +268,9 @@ def prepare_variable( "h5netcdf does not yet support setting a fill value for " "variable-length strings " "(https://github.com/shoyer/h5netcdf/issues/37). " - "Either remove '_FillValue' from encoding on variable %r " + f"Either remove '_FillValue' from encoding on variable {name!r} " "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." % name + "NC_CHAR type." ) if dtype is str: @@ -329,14 +337,15 @@ def close(self, **kwargs): class H5netcdfBackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): - try: - return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") - except TypeError: - pass + available = has_h5netcdf + + def guess_can_open(self, filename_or_obj): + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None: + return magic_number.startswith(b"\211HDF\r\n\032\n") try: - _, ext = os.path.splitext(store_spec) + _, ext = os.path.splitext(filename_or_obj) except TypeError: return False @@ -347,9 +356,9 @@ def open_dataset( filename_or_obj, *, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, @@ -358,8 +367,10 @@ def open_dataset( lock=None, invalid_netcdf=None, phony_dims=None, + decode_vlen_strings=True, ): + filename_or_obj = _normalize_path(filename_or_obj) store = H5NetCDFStore.open( filename_or_obj, format=format, @@ -367,6 +378,7 @@ def open_dataset( lock=lock, invalid_netcdf=invalid_netcdf, phony_dims=phony_dims, + decode_vlen_strings=decode_vlen_strings, ) store_entrypoint = StoreBackendEntrypoint() @@ -384,5 +396,4 @@ def open_dataset( return ds -if has_h5netcdf: - BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint +BACKEND_ENTRYPOINTS["h5netcdf"] = H5netcdfBackendEntrypoint diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index bb876a432c8..59417336f5f 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -67,7 +67,7 @@ def _get_scheduler(get=None, collection=None) -> Optional[str]: None is returned if no dask scheduler is active. - See also + See Also -------- dask.base.get_scheduler """ @@ -167,7 +167,7 @@ def locked(self): return any(lock.locked for lock in self.locks) def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) + return f"CombinedLock({list(self.locks)!r})" class DummyLock: diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 5ca49a0311a..48030903036 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -34,7 +34,7 @@ def __init__(self, maxsize: int, on_evict: Callable[[K, V], Any] = None): ---------- maxsize : int Integer maximum number of items to hold in the cache. - on_evict: callable, optional + on_evict : callable, optional Function to call like ``on_evict(key, value)`` when items are evicted. """ diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index e3d87aaf83f..769c96c99ce 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,13 +9,19 @@ from .. import coding from ..coding.variables import pop_to from ..core import indexing -from ..core.utils import FrozenDict, close_on_error, is_remote_uri +from ..core.utils import ( + FrozenDict, + close_on_error, + is_remote_uri, + try_read_magic_number_from_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, + _normalize_path, find_root_and_group, robust_getitem, ) @@ -121,25 +127,23 @@ def _encode_nc4_variable(var): def _check_encoding_dtype_is_vlen_string(dtype): if dtype is not str: raise AssertionError( # pragma: no cover - "unexpected dtype encoding %r. This shouldn't happen: please " - "file a bug report at github.com/pydata/xarray" % dtype + f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please " + "file a bug report at github.com/pydata/xarray" ) def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): if nc_format == "NETCDF4": - datatype = _nc4_dtype(var) - else: - if "dtype" in var.encoding: - encoded_dtype = var.encoding["dtype"] - _check_encoding_dtype_is_vlen_string(encoded_dtype) - if raise_on_invalid_encoding: - raise ValueError( - "encoding dtype=str for vlen strings is only supported " - "with format='NETCDF4'." - ) - datatype = var.dtype - return datatype + return _nc4_dtype(var) + if "dtype" in var.encoding: + encoded_dtype = var.encoding["dtype"] + _check_encoding_dtype_is_vlen_string(encoded_dtype) + if raise_on_invalid_encoding: + raise ValueError( + "encoding dtype=str for vlen strings is only supported " + "with format='NETCDF4'." + ) + return var.dtype def _nc4_dtype(var): @@ -177,7 +181,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message - raise OSError("group not found: %s" % key, e) + raise OSError(f"group not found: {key}", e) return ds @@ -202,7 +206,7 @@ def _force_native_endianness(var): # if endian exists, remove it from the encoding. var.encoding.pop("endian", None) # check to see if encoding has a value for endian its 'native' - if not var.encoding.get("endian", "native") == "native": + if var.encoding.get("endian", "native") != "native": raise NotImplementedError( "Attempt to write non-native endian type, " "this is not supported by the netCDF4 " @@ -269,8 +273,8 @@ def _extract_nc4_variable_encoding( invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - "unexpected encoding parameters for %r backend: %r. Valid " - "encodings are: %r" % (backend, invalid, valid_encodings) + f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. Valid " + f"encodings are: {valid_encodings!r}" ) else: for k in list(encoding): @@ -281,10 +285,8 @@ def _extract_nc4_variable_encoding( def _is_list_of_strings(value): - if np.asarray(value).dtype.kind in ["U", "S"] and np.asarray(value).size > 1: - return True - else: - return False + arr = np.asarray(value) + return arr.dtype.kind in ["U", "S"] and arr.size > 1 class NetCDF4DataStore(WritableCFDataStore): @@ -312,7 +314,7 @@ def __init__( if group is None: root, group = find_root_and_group(manager) else: - if not type(manager) is netCDF4.Dataset: + if type(manager) is not netCDF4.Dataset: raise ValueError( "must supply a root netCDF4.Dataset if the group " "argument is provided" @@ -388,7 +390,7 @@ def ds(self): def open_store_variable(self, name, var): dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray(NetCDF4ArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = {k: var.getncattr(k) for k in var.ncattrs()} _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later @@ -416,25 +418,22 @@ def open_store_variable(self, name, var): return Variable(dimensions, data, attributes, encoding) def get_variables(self): - dsvars = FrozenDict( + return FrozenDict( (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items() ) - return dsvars def get_attrs(self): - attrs = FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) - return attrs + return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs()) def get_dimensions(self): - dims = FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) - return dims + return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items()) def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited() + return { + "unlimited_dims": { + k for k, v in self.ds.dimensions.items() if v.isunlimited() + } } - return encoding def set_dimension(self, name, length, is_unlimited=False): dim_length = length if not is_unlimited else None @@ -472,9 +471,9 @@ def prepare_variable( "netCDF4 does not yet support setting a fill value for " "variable-length strings " "(https://github.com/Unidata/netcdf4-python/issues/730). " - "Either remove '_FillValue' from encoding on variable %r " + f"Either remove '_FillValue' from encoding on variable {name!r} " "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." % name + "NC_CHAR type." ) encoding = _extract_nc4_variable_encoding( @@ -513,11 +512,17 @@ def close(self, **kwargs): class NetCDF4BackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): - if isinstance(store_spec, str) and is_remote_uri(store_spec): + available = has_netcdf4 + + def guess_can_open(self, filename_or_obj): + if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True + magic_number = try_read_magic_number_from_path(filename_or_obj) + if magic_number is not None: + # netcdf 3 or HDF5 + return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) try: - _, ext = os.path.splitext(store_spec) + _, ext = os.path.splitext(filename_or_obj) except TypeError: return False return ext in {".nc", ".nc4", ".cdf"} @@ -526,9 +531,9 @@ def open_dataset( self, filename_or_obj, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, @@ -542,6 +547,7 @@ def open_dataset( autoclose=False, ): + filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( filename_or_obj, mode=mode, @@ -569,5 +575,4 @@ def open_dataset( return ds -if has_netcdf4: - BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint +BACKEND_ENTRYPOINTS["netcdf4"] = NetCDF4BackendEntrypoint diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index 001af0bf8e1..5fdd0534d57 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -125,8 +125,6 @@ def is_valid_nc3_name(s): """ if not isinstance(s, str): return False - if not isinstance(s, str): - s = s.decode("utf-8") num_bytes = len(s.encode("utf-8")) return ( (unicodedata.normalize("NFC", s) == s) diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index b8cd2bf6378..08c1bec8325 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -1,26 +1,25 @@ import functools import inspect import itertools -import logging import warnings import pkg_resources -from .common import BACKEND_ENTRYPOINTS +from .common import BACKEND_ENTRYPOINTS, BackendEntrypoint +STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"] -def remove_duplicates(backend_entrypoints): + +def remove_duplicates(pkg_entrypoints): # sort and group entrypoints by name - backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name) - backend_entrypoints_grouped = itertools.groupby( - backend_entrypoints, key=lambda ep: ep.name - ) + pkg_entrypoints = sorted(pkg_entrypoints, key=lambda ep: ep.name) + pkg_entrypoints_grouped = itertools.groupby(pkg_entrypoints, key=lambda ep: ep.name) # check if there are multiple entrypoints for the same name - unique_backend_entrypoints = [] - for name, matches in backend_entrypoints_grouped: + unique_pkg_entrypoints = [] + for name, matches in pkg_entrypoints_grouped: matches = list(matches) - unique_backend_entrypoints.append(matches[0]) + unique_pkg_entrypoints.append(matches[0]) matches_len = len(matches) if matches_len > 1: selected_module_name = matches[0].module_name @@ -30,7 +29,7 @@ def remove_duplicates(backend_entrypoints): f"\n {all_module_names}.\n It will be used: {selected_module_name}.", RuntimeWarning, ) - return unique_backend_entrypoints + return unique_pkg_entrypoints def detect_parameters(open_dataset): @@ -51,13 +50,16 @@ def detect_parameters(open_dataset): return tuple(parameters_list) -def create_engines_dict(backend_entrypoints): - engines = {} - for backend_ep in backend_entrypoints: - name = backend_ep.name - backend = backend_ep.load() - engines[name] = backend - return engines +def backends_dict_from_pkg(pkg_entrypoints): + backend_entrypoints = {} + for pkg_ep in pkg_entrypoints: + name = pkg_ep.name + try: + backend = pkg_ep.load() + backend_entrypoints[name] = backend + except Exception as ex: + warnings.warn(f"Engine {name!r} loading failed:\n{ex}", RuntimeWarning) + return backend_entrypoints def set_missing_parameters(backend_entrypoints): @@ -67,22 +69,34 @@ def set_missing_parameters(backend_entrypoints): backend.open_dataset_parameters = detect_parameters(open_dataset) -def build_engines(entrypoints): - backend_entrypoints = BACKEND_ENTRYPOINTS.copy() - pkg_entrypoints = remove_duplicates(entrypoints) - external_backend_entrypoints = create_engines_dict(pkg_entrypoints) +def sort_backends(backend_entrypoints): + ordered_backends_entrypoints = {} + for be_name in STANDARD_BACKENDS_ORDER: + if backend_entrypoints.get(be_name, None) is not None: + ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name) + ordered_backends_entrypoints.update( + {name: backend_entrypoints[name] for name in sorted(backend_entrypoints)} + ) + return ordered_backends_entrypoints + + +def build_engines(pkg_entrypoints): + backend_entrypoints = {} + for backend_name, backend in BACKEND_ENTRYPOINTS.items(): + if backend.available: + backend_entrypoints[backend_name] = backend + pkg_entrypoints = remove_duplicates(pkg_entrypoints) + external_backend_entrypoints = backends_dict_from_pkg(pkg_entrypoints) backend_entrypoints.update(external_backend_entrypoints) + backend_entrypoints = sort_backends(backend_entrypoints) set_missing_parameters(backend_entrypoints) - engines = {} - for name, backend in backend_entrypoints.items(): - engines[name] = backend() - return engines + return {name: backend() for name, backend in backend_entrypoints.items()} @functools.lru_cache(maxsize=1) def list_engines(): - entrypoints = pkg_resources.iter_entry_points("xarray.backends") - return build_engines(entrypoints) + pkg_entrypoints = pkg_resources.iter_entry_points("xarray.backends") + return build_engines(pkg_entrypoints) def guess_engine(store_spec): @@ -90,19 +104,67 @@ def guess_engine(store_spec): for engine, backend in engines.items(): try: - if backend.guess_can_open and backend.guess_can_open(store_spec): + if backend.guess_can_open(store_spec): return engine except Exception: - logging.exception(f"{engine!r} fails while guessing") + warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning) - raise ValueError("cannot guess the engine, try passing one explicitly") + compatible_engines = [] + for engine, backend_cls in BACKEND_ENTRYPOINTS.items(): + try: + backend = backend_cls() + if backend.guess_can_open(store_spec): + compatible_engines.append(engine) + except Exception: + warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning) + + installed_engines = [k for k in engines if k != "store"] + if not compatible_engines: + if installed_engines: + error_msg = ( + "did not find a match in any of xarray's currently installed IO " + f"backends {installed_engines}. Consider explicitly selecting one of the " + "installed engines via the ``engine`` parameter, or installing " + "additional IO dependencies, see:\n" + "http://xarray.pydata.org/en/stable/getting-started-guide/installing.html\n" + "http://xarray.pydata.org/en/stable/user-guide/io.html" + ) + else: + error_msg = ( + "xarray is unable to open this file because it has no currently " + "installed IO backends. Xarray's read/write support requires " + "installing optional IO dependencies, see:\n" + "http://xarray.pydata.org/en/stable/getting-started-guide/installing.html\n" + "http://xarray.pydata.org/en/stable/user-guide/io" + ) + else: + error_msg = ( + "found the following matches with the input file in xarray's IO " + f"backends: {compatible_engines}. But their dependencies may not be installed, see:\n" + "http://xarray.pydata.org/en/stable/user-guide/io.html \n" + "http://xarray.pydata.org/en/stable/getting-started-guide/installing.html" + ) + + raise ValueError(error_msg) def get_backend(engine): - """Select open_dataset method based on current engine""" - engines = list_engines() - if engine not in engines: - raise ValueError( - f"unrecognized engine {engine} must be one of: {list(engines)}" + """Select open_dataset method based on current engine.""" + if isinstance(engine, str): + engines = list_engines() + if engine not in engines: + raise ValueError( + f"unrecognized engine {engine} must be one of: {list(engines)}" + ) + backend = engines[engine] + elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint): + backend = engine + else: + raise TypeError( + ( + "engine must be a string or a subclass of " + f"xarray.backends.BackendEntrypoint: {engine}" + ) ) - return engines[engine] + + return backend diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 80485fce459..da178926dbe 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -8,6 +8,7 @@ AbstractDataStore, BackendArray, BackendEntrypoint, + _normalize_path, ) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock @@ -74,7 +75,7 @@ def ds(self): return self._manager.acquire() def open_store_variable(self, name, var): - data = indexing.LazilyOuterIndexedArray(PncArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(PncArrayWrapper(name, self)) attrs = {k: getattr(var, k) for k in var.ncattrs()} return Variable(var.dimensions, data, attrs) @@ -101,6 +102,7 @@ def close(self): class PseudoNetCDFBackendEntrypoint(BackendEntrypoint): + available = has_pseudonetcdf # *args and **kwargs are not allowed in open_backend_dataset_ kwargs, # unless the open_dataset_parameters are explicity defined like this: @@ -121,9 +123,9 @@ def open_dataset( self, filename_or_obj, mask_and_scale=False, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, @@ -131,6 +133,8 @@ def open_dataset( lock=None, **format_kwargs, ): + + filename_or_obj = _normalize_path(filename_or_obj) store = PseudoNetCDFDataStore.open( filename_or_obj, lock=lock, mode=mode, **format_kwargs ) @@ -150,5 +154,4 @@ def open_dataset( return ds -if has_pseudonetcdf: - BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint +BACKEND_ENTRYPOINTS["pseudonetcdf"] = PseudoNetCDFBackendEntrypoint diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 7f8622ca66e..bc479f9a71d 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from ..core import indexing @@ -45,7 +47,7 @@ def _getitem(self, key): result = robust_getitem(array, key, catch=ValueError) # in some cases, pydap doesn't squeeze axes automatically like numpy axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types)) - if result.ndim + len(axis) != array.ndim and len(axis) > 0: + if result.ndim + len(axis) != array.ndim and axis: result = np.squeeze(result, axis) return result @@ -92,7 +94,7 @@ def open(cls, url, session=None): return cls(ds) def open_store_variable(self, var): - data = indexing.LazilyOuterIndexedArray(PydapArrayWrapper(var)) + data = indexing.LazilyIndexedArray(PydapArrayWrapper(var)) return Variable(var.dimensions, data, _fix_attributes(var.attributes)) def get_variables(self): @@ -108,21 +110,32 @@ def get_dimensions(self): class PydapBackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): - return isinstance(store_spec, str) and is_remote_uri(store_spec) + available = has_pydap + + def guess_can_open(self, filename_or_obj): + return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj) def open_dataset( self, filename_or_obj, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, session=None, + lock=None, ): + # TODO remove after v0.19 + if lock is not None: + warnings.warn( + "The kwarg 'lock' has been deprecated for this backend, and is now " + "ignored. In the future passing lock will raise an error.", + DeprecationWarning, + ) + store = PydapDataStore.open( filename_or_obj, session=session, @@ -143,5 +156,4 @@ def open_dataset( return ds -if has_pydap: - BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint +BACKEND_ENTRYPOINTS["pydap"] = PydapBackendEntrypoint diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 41c99efd076..4e912f3e1ef 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -8,6 +8,7 @@ AbstractDataStore, BackendArray, BackendEntrypoint, + _normalize_path, ) from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock @@ -74,7 +75,7 @@ def ds(self): return self._manager.acquire() def open_store_variable(self, name, var): - data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): @@ -98,18 +99,22 @@ def close(self): class PynioBackendEntrypoint(BackendEntrypoint): + available = has_pynio + def open_dataset( + self, filename_or_obj, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, mode="r", lock=None, ): + filename_or_obj = _normalize_path(filename_or_obj) store = NioDataStore( filename_or_obj, mode=mode, @@ -131,5 +136,4 @@ def open_dataset( return ds -if has_pynio: - BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint +BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index c689c1e99d7..1891fac8668 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -52,9 +52,9 @@ def shape(self): def _get_indexer(self, key): """Get indexer for rasterio array. - Parameter - --------- - key: tuple of int + Parameters + ---------- + key : tuple of int Returns ------- @@ -63,7 +63,7 @@ def _get_indexer(self, key): squeeze_axis: axes to be squeezed np_ind: indexer for loaded numpy array - See also + See Also -------- indexing.decompose_indexer """ @@ -162,7 +162,14 @@ def default(s): return parsed_meta -def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=None): +def open_rasterio( + filename, + parse_coordinates=None, + chunks=None, + cache=None, + lock=None, + **kwargs, +): """Open a file with rasterio (experimental). This should work with any file that rasterio can open (most often: @@ -174,12 +181,46 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc You can generate 2D coordinates from the file's attributes with:: - from affine import Affine - da = xr.open_rasterio('path_to_file.tif') - transform = Affine.from_gdal(*da.attrs['transform']) - nx, ny = da.sizes['x'], da.sizes['y'] - x, y = np.meshgrid(np.arange(nx)+0.5, np.arange(ny)+0.5) * transform - + >>> from affine import Affine + >>> da = xr.open_rasterio( + ... "https://github.com/mapbox/rasterio/raw/1.2.1/tests/data/RGB.byte.tif" + ... ) + >>> da + + [1703814 values with dtype=uint8] + Coordinates: + * band (band) int64 1 2 3 + * y (y) float64 2.827e+06 2.826e+06 2.826e+06 ... 2.612e+06 2.612e+06 + * x (x) float64 1.021e+05 1.024e+05 1.027e+05 ... 3.389e+05 3.392e+05 + Attributes: + transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.041782729805... + crs: +init=epsg:32618 + res: (300.0379266750948, 300.041782729805) + is_tiled: 0 + nodatavals: (0.0, 0.0, 0.0) + scales: (1.0, 1.0, 1.0) + offsets: (0.0, 0.0, 0.0) + AREA_OR_POINT: Area + >>> transform = Affine(*da.attrs["transform"]) + >>> transform + Affine(300.0379266750948, 0.0, 101985.0, + 0.0, -300.041782729805, 2826915.0) + >>> nx, ny = da.sizes["x"], da.sizes["y"] + >>> x, y = transform * np.meshgrid(np.arange(nx) + 0.5, np.arange(ny) + 0.5) + >>> x + array([[102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666], + [102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666], + [102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666], + ..., + [102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666], + [102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666], + [102135.01896334, 102435.05689001, 102735.09481669, ..., + 338564.90518331, 338864.94310999, 339164.98103666]]) Parameters ---------- @@ -238,7 +279,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc if lock is None: lock = RASTERIO_LOCK - manager = CachingFileManager(rasterio.open, filename, lock=lock, mode="r") + manager = CachingFileManager( + rasterio.open, + filename, + lock=lock, + mode="r", + kwargs=kwargs, + ) riods = manager.acquire() if vrt_params is not None: riods = WarpedVRT(riods, **vrt_params) @@ -336,9 +383,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray( - RasterioArrayWrapper(manager, lock, vrt_params) - ) + data = indexing.LazilyIndexedArray(RasterioArrayWrapper(manager, lock, vrt_params)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) @@ -357,7 +402,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc # the filename is probably an s3 bucket rather than a regular file mtime = None token = tokenize(filename, mtime, chunks) - name_prefix = "open_rasterio-%s" % token + name_prefix = f"open_rasterio-{token}" result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index ddc157ed8e4..4c1ce1ef09d 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,16 +1,23 @@ +import gzip import io import os import numpy as np from ..core.indexing import NumpyIndexingAdapter -from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number +from ..core.utils import ( + Frozen, + FrozenDict, + close_on_error, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, + _normalize_path, ) from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock @@ -71,8 +78,6 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): - import gzip - # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -127,7 +132,7 @@ def __init__( elif format == "NETCDF3_CLASSIC": version = 1 else: - raise ValueError("invalid format for scipy.io.netcdf backend: %r" % format) + raise ValueError(f"invalid format for scipy.io.netcdf backend: {format!r}") if lock is None and mode != "r" and isinstance(filename_or_obj, str): lock = get_write_lock(filename_or_obj) @@ -173,16 +178,14 @@ def get_dimensions(self): return Frozen(self.ds.dimensions) def get_encoding(self): - encoding = {} - encoding["unlimited_dims"] = { - k for k, v in self.ds.dimensions.items() if v is None + return { + "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None} } - return encoding def set_dimension(self, name, length, is_unlimited=False): if name in self.ds.dimensions: raise ValueError( - "%s does not support modifying dimensions" % type(self).__name__ + f"{type(self).__name__} does not support modifying dimensions" ) dim_length = length if not is_unlimited else None self.ds.createDimension(name, dim_length) @@ -203,12 +206,14 @@ def encode_variable(self, variable): def prepare_variable( self, name, variable, check_encoding=False, unlimited_dims=None ): - if check_encoding and variable.encoding: - if variable.encoding != {"_FillValue": None}: - raise ValueError( - "unexpected encoding for scipy backend: %r" - % list(variable.encoding) - ) + if ( + check_encoding + and variable.encoding + and variable.encoding != {"_FillValue": None} + ): + raise ValueError( + f"unexpected encoding for scipy backend: {list(variable.encoding)}" + ) data = variable.data # nb. this still creates a numpy array in all memory, even though we @@ -233,14 +238,19 @@ def close(self): class ScipyBackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): - try: - return read_magic_number(store_spec).startswith(b"CDF") - except TypeError: - pass + available = has_scipy + + def guess_can_open(self, filename_or_obj): + + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): + with gzip.open(filename_or_obj) as f: + magic_number = try_read_magic_number_from_file_or_path(f) + if magic_number is not None: + return magic_number.startswith(b"CDF") try: - _, ext = os.path.splitext(store_spec) + _, ext = os.path.splitext(filename_or_obj) except TypeError: return False return ext in {".nc", ".nc4", ".cdf", ".gz"} @@ -249,9 +259,9 @@ def open_dataset( self, filename_or_obj, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, @@ -262,6 +272,7 @@ def open_dataset( lock=None, ): + filename_or_obj = _normalize_path(filename_or_obj) store = ScipyDataStore( filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock ) @@ -281,5 +292,4 @@ def open_dataset( return ds -if has_scipy: - BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint +BACKEND_ENTRYPOINTS["scipy"] = ScipyBackendEntrypoint diff --git a/xarray/backends/store.py b/xarray/backends/store.py index d57b3ab9df8..b774d2bce95 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -4,8 +4,10 @@ class StoreBackendEntrypoint(BackendEntrypoint): - def guess_can_open(self, store_spec): - return isinstance(store_spec, AbstractDataStore) + available = True + + def guess_can_open(self, filename_or_obj): + return isinstance(filename_or_obj, AbstractDataStore) def open_dataset( self, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 1d667a38b53..aec12d2b154 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,5 +1,7 @@ import os import pathlib +import warnings +from distutils.version import LooseVersion import numpy as np @@ -14,6 +16,7 @@ BackendArray, BackendEntrypoint, _encode_variable_name, + _normalize_path, ) from .store import StoreBackendEntrypoint @@ -64,7 +67,7 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def get_array(self): - return self.datastore.ds[self.variable_name] + return self.datastore.zarr_group[self.variable_name] def __getitem__(self, key): array = self.get_array() @@ -81,7 +84,7 @@ def __getitem__(self, key): # could possibly have a work-around for 0d data here -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): """ Given encoding chunks (possibly None) and variable chunks (possibly None) """ @@ -131,7 +134,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): if len(enc_chunks_tuple) != ndim: # throw away encoding chunks, start over - return _determine_zarr_chunks(None, var_chunks, ndim, name) + return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks) for x in enc_chunks_tuple: if not isinstance(x, int): @@ -162,24 +165,32 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): continue for dchunk in dchunks[:-1]: if dchunk % zchunk: - raise NotImplementedError( + base_error = ( f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " - "This is not implemented in xarray yet. " - "Consider either rechunking using `chunk()` or instead deleting " - "or modifying `encoding['chunks']`." + f"Writing this array in parallel with dask could lead to corrupted data." ) + if safe_chunks: + raise NotImplementedError( + base_error + + " Consider either rechunking using `chunk()`, deleting " + "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." + ) if dchunks[-1] > zchunk: - raise ValueError( + base_error = ( "Final chunk of Zarr array must be the same size or " "smaller than the first. " f"Specified Zarr chunk encoding['chunks']={enc_chunks_tuple}, " f"for variable named {name!r} " - f"but {dchunks} in the variable's Dask chunks {var_chunks} is " + f"but {dchunks} in the variable's Dask chunks {var_chunks} are " "incompatible with this encoding. " - "Consider either rechunking using `chunk()` or instead deleting " - "or modifying `encoding['chunks']`." ) + if safe_chunks: + raise NotImplementedError( + base_error + + " Consider either rechunking using `chunk()`, deleting " + "or modifying `encoding['chunks']`, or specify `safe_chunks=False`." + ) return enc_chunks_tuple raise AssertionError("We should never get here. Function logic must be wrong.") @@ -194,14 +205,16 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): dimensions = zarr_obj.attrs[dimension_key] except KeyError: raise KeyError( - "Zarr object is missing the attribute `%s`, which is " - "required for xarray to determine variable dimensions." % (dimension_key) + f"Zarr object is missing the attribute `{dimension_key}`, which is " + "required for xarray to determine variable dimensions." ) attributes = HiddenKeyDict(zarr_obj.attrs, [dimension_key]) return dimensions, attributes -def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=None): +def extract_zarr_variable_encoding( + variable, raise_on_invalid=False, name=None, safe_chunks=True +): """ Extract zarr encoding dictionary from xarray Variable @@ -223,7 +236,7 @@ def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=None): invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - "unexpected encoding parameters for zarr backend: %r" % invalid + f"unexpected encoding parameters for zarr backend: {invalid!r}" ) else: for k in list(encoding): @@ -231,7 +244,7 @@ def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=None): del encoding[k] chunks = _determine_zarr_chunks( - encoding.get("chunks"), variable.chunks, variable.ndim, name + encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks ) encoding["chunks"] = chunks return encoding @@ -272,17 +285,47 @@ def encode_zarr_variable(var, needs_copy=True, name=None): return var +def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim): + if new_var.dims != existing_var.dims: + raise ValueError( + f"variable {var_name!r} already exists with different " + f"dimension names {existing_var.dims} != " + f"{new_var.dims}, but changing variable " + f"dimensions is not supported by to_zarr()." + ) + + existing_sizes = {} + for dim, size in existing_var.sizes.items(): + if region is not None and dim in region: + start, stop, stride = region[dim].indices(size) + assert stride == 1 # region was already validated + size = stop - start + if dim != append_dim: + existing_sizes[dim] = size + + new_sizes = {dim: size for dim, size in new_var.sizes.items() if dim != append_dim} + if existing_sizes != new_sizes: + raise ValueError( + f"variable {var_name!r} already exists with different " + f"dimension sizes: {existing_sizes} != {new_sizes}. " + f"to_zarr() only supports changing dimension sizes when " + f"explicitly appending, but append_dim={append_dim!r}." + ) + + class ZarrStore(AbstractWritableDataStore): """Store for reading and writing data via zarr""" __slots__ = ( - "ds", + "zarr_group", "_append_dim", "_consolidate_on_close", "_group", + "_mode", "_read_only", "_synchronizer", "_write_region", + "_safe_chunks", ) @classmethod @@ -295,38 +338,92 @@ def open_group( consolidated=False, consolidate_on_close=False, chunk_store=None, + storage_options=None, append_dim=None, write_region=None, + safe_chunks=True, + stacklevel=2, ): # zarr doesn't support pathlib.Path objects yet. zarr-python#601 if isinstance(store, pathlib.Path): store = os.fspath(store) - open_kwargs = dict(mode=mode, synchronizer=synchronizer, path=group) + open_kwargs = dict( + mode=mode, + synchronizer=synchronizer, + path=group, + ) + if LooseVersion(zarr.__version__) >= "2.5.0": + open_kwargs["storage_options"] = storage_options + elif storage_options: + raise ValueError("Storage options only compatible with zarr>=2.5.0") + if chunk_store: open_kwargs["chunk_store"] = chunk_store + if consolidated is None: + consolidated = False - if consolidated: + if consolidated is None: + try: + zarr_group = zarr.open_consolidated(store, **open_kwargs) + except KeyError: + warnings.warn( + "Failed to open Zarr store with consolidated metadata, " + "falling back to try reading non-consolidated metadata. " + "This is typically much slower for opening a dataset. " + "To silence this warning, consider:\n" + "1. Consolidating metadata in this existing store with " + "zarr.consolidate_metadata().\n" + "2. Explicitly setting consolidated=False, to avoid trying " + "to read consolidate metadata, or\n" + "3. Explicitly setting consolidated=True, to raise an " + "error in this case instead of falling back to try " + "reading non-consolidated metadata.", + RuntimeWarning, + stacklevel=stacklevel, + ) + zarr_group = zarr.open_group(store, **open_kwargs) + elif consolidated: # TODO: an option to pass the metadata_key keyword zarr_group = zarr.open_consolidated(store, **open_kwargs) else: zarr_group = zarr.open_group(store, **open_kwargs) - return cls(zarr_group, consolidate_on_close, append_dim, write_region) + return cls( + zarr_group, + mode, + consolidate_on_close, + append_dim, + write_region, + safe_chunks, + ) def __init__( - self, zarr_group, consolidate_on_close=False, append_dim=None, write_region=None + self, + zarr_group, + mode=None, + consolidate_on_close=False, + append_dim=None, + write_region=None, + safe_chunks=True, ): - self.ds = zarr_group - self._read_only = self.ds.read_only - self._synchronizer = self.ds.synchronizer - self._group = self.ds.path + self.zarr_group = zarr_group + self._read_only = self.zarr_group.read_only + self._synchronizer = self.zarr_group.synchronizer + self._group = self.zarr_group.path + self._mode = mode self._consolidate_on_close = consolidate_on_close self._append_dim = append_dim self._write_region = write_region + self._safe_chunks = safe_chunks + + @property + def ds(self): + # TODO: consider deprecating this in favor of zarr_group + return self.zarr_group def open_store_variable(self, name, zarr_array): - data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) + data = indexing.LazilyIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, DIMENSION_KEY) attributes = dict(attributes) encoding = { @@ -344,30 +441,29 @@ def open_store_variable(self, name, zarr_array): def get_variables(self): return FrozenDict( - (k, self.open_store_variable(k, v)) for k, v in self.ds.arrays() + (k, self.open_store_variable(k, v)) for k, v in self.zarr_group.arrays() ) def get_attrs(self): - attributes = dict(self.ds.attrs.asdict()) - return attributes + return dict(self.zarr_group.attrs.asdict()) def get_dimensions(self): dimensions = {} - for k, v in self.ds.arrays(): + for k, v in self.zarr_group.arrays(): try: for d, s in zip(v.attrs[DIMENSION_KEY], v.shape): if d in dimensions and dimensions[d] != s: raise ValueError( - "found conflicting lengths for dimension %s " - "(%d != %d)" % (d, s, dimensions[d]) + f"found conflicting lengths for dimension {d} " + f"({s} != {dimensions[d]})" ) dimensions[d] = s except KeyError: raise KeyError( - "Zarr object is missing the attribute `%s`, " + f"Zarr object is missing the attribute `{DIMENSION_KEY}`, " "which is required for xarray to determine " - "variable dimensions." % (DIMENSION_KEY) + "variable dimensions." ) return dimensions @@ -378,7 +474,7 @@ def set_dimensions(self, variables, unlimited_dims=None): ) def set_attributes(self, attributes): - self.ds.attrs.put(attributes) + self.zarr_group.attrs.put(attributes) def encode_variable(self, variable): variable = encode_zarr_variable(variable) @@ -417,35 +513,50 @@ def store( dimension on which the zarray will be appended only needed in append mode """ - - existing_variables = { - vn for vn in variables if _encode_variable_name(vn) in self.ds + existing_variable_names = { + vn for vn in variables if _encode_variable_name(vn) in self.zarr_group } - new_variables = set(variables) - existing_variables + new_variables = set(variables) - existing_variable_names variables_without_encoding = {vn: variables[vn] for vn in new_variables} variables_encoded, attributes = self.encode( variables_without_encoding, attributes ) - if len(existing_variables) > 0: - # there are variables to append - # their encoding must be the same as in the store - ds = open_zarr(self.ds.store, group=self.ds.path, chunks=None) - variables_with_encoding = {} - for vn in existing_variables: - variables_with_encoding[vn] = variables[vn].copy(deep=False) - variables_with_encoding[vn].encoding = ds[vn].encoding - variables_with_encoding, _ = self.encode(variables_with_encoding, {}) - variables_encoded.update(variables_with_encoding) - - if self._write_region is None: + if existing_variable_names: + # Decode variables directly, without going via xarray.Dataset to + # avoid needing to load index variables into memory. + # TODO: consider making loading indexes lazy again? + existing_vars, _, _ = conventions.decode_cf_variables( + self.get_variables(), self.get_attrs() + ) + # Modified variables must use the same encoding as the store. + vars_with_encoding = {} + for vn in existing_variable_names: + vars_with_encoding[vn] = variables[vn].copy(deep=False) + vars_with_encoding[vn].encoding = existing_vars[vn].encoding + vars_with_encoding, _ = self.encode(vars_with_encoding, {}) + variables_encoded.update(vars_with_encoding) + + for var_name in existing_variable_names: + new_var = variables_encoded[var_name] + existing_var = existing_vars[var_name] + _validate_existing_dims( + var_name, + new_var, + existing_var, + self._write_region, + self._append_dim, + ) + + if self._mode not in ["r", "r+"]: self.set_attributes(attributes) self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims) + self.set_variables( variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims ) if self._consolidate_on_close: - zarr.consolidate_metadata(self.ds.store) + zarr.consolidate_metadata(self.zarr_group.store) def sync(self): pass @@ -462,7 +573,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No check_encoding_set : list-like List of variables that should be checked for invalid encoding values - writer : + writer unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. @@ -480,13 +591,16 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} - if name in self.ds: + if name in self.zarr_group: # existing variable - zarr_array = self.ds[name] + # TODO: if mode="a", consider overriding the existing variable + # metadata. This would need some case work properly with region + # and append_dim. + zarr_array = self.zarr_group[name] else: # new variable encoding = extract_zarr_variable_encoding( - v, raise_on_invalid=check, name=vn + v, raise_on_invalid=check, name=vn, safe_chunks=self._safe_chunks ) encoded_attrs = {} # the magic for storing the hidden dimension data @@ -496,7 +610,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No if coding.strings.check_vlen_dtype(dtype) == str: dtype = str - zarr_array = self.ds.create( + zarr_array = self.zarr_group.create( name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding ) zarr_array.attrs.put(encoded_attrs) @@ -534,19 +648,16 @@ def open_zarr( concat_characters=True, decode_coords=True, drop_variables=None, - consolidated=False, + consolidated=None, overwrite_encoded_chunks=False, chunk_store=None, + storage_options=None, decode_timedelta=None, use_cftime=None, **kwargs, ): """Load and decode a dataset from a Zarr store. - .. note:: Experimental - The Zarr backend is new and experimental. Please report any - unexpected behavior via github issues. - The `store` object should be a valid store for a Zarr group. `store` variables must contain dimension metadata encoded in the `_ARRAY_DIMENSIONS` attribute. @@ -566,7 +677,7 @@ def open_zarr( based on the variable's zarr chunks. If `chunks=None`, zarr array data will lazily convert to numpy arrays upon access. This accepts all the chunk specifications as Dask does. - overwrite_encoded_chunks: bool, optional + overwrite_encoded_chunks : bool, optional Whether to drop the zarr chunks encoded for each variable when a dataset is loaded with specified chunk sizes (default: False) decode_cf : bool, optional @@ -598,6 +709,8 @@ def open_zarr( consolidated : bool, optional Whether to open the store using zarr's consolidated metadata capability. Only works for stores that have already been consolidated. + By default (`consolidate=None`), attempts to read consolidated metadata, + falling back to read non-consolidated metadata if that fails. chunk_store : MutableMapping, optional A separate Zarr store only for chunk data. decode_timedelta : bool, optional @@ -605,7 +718,7 @@ def open_zarr( {'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds'} into timedelta objects. If False, leave them encoded as numbers. If None (default), assume the same value of decode_time. - use_cftime: bool, optional + use_cftime : bool, optional Only relevant if encoded dates come from a standard calendar (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to @@ -624,6 +737,7 @@ def open_zarr( See Also -------- open_dataset + open_mfdataset References ---------- @@ -649,6 +763,8 @@ def open_zarr( "consolidated": consolidated, "overwrite_encoded_chunks": overwrite_encoded_chunks, "chunk_store": chunk_store, + "storage_options": storage_options, + "stacklevel": 4, } ds = open_dataset( @@ -666,36 +782,57 @@ def open_zarr( decode_timedelta=decode_timedelta, use_cftime=use_cftime, ) - return ds class ZarrBackendEntrypoint(BackendEntrypoint): + available = has_zarr + + def guess_can_open(self, filename_or_obj): + try: + _, ext = os.path.splitext(filename_or_obj) + except TypeError: + return False + return ext in {".zarr"} + def open_dataset( self, filename_or_obj, mask_and_scale=True, - decode_times=None, - concat_characters=None, - decode_coords=None, + decode_times=True, + concat_characters=True, + decode_coords=True, drop_variables=None, use_cftime=None, decode_timedelta=None, group=None, mode="r", synchronizer=None, - consolidated=False, - consolidate_on_close=False, + consolidated=None, chunk_store=None, + storage_options=None, + stacklevel=3, + lock=None, ): + # TODO remove after v0.19 + if lock is not None: + warnings.warn( + "The kwarg 'lock' has been deprecated for this backend, and is now " + "ignored. In the future passing lock will raise an error.", + DeprecationWarning, + ) + + filename_or_obj = _normalize_path(filename_or_obj) store = ZarrStore.open_group( filename_or_obj, group=group, mode=mode, synchronizer=synchronizer, consolidated=consolidated, - consolidate_on_close=consolidate_on_close, + consolidate_on_close=False, chunk_store=chunk_store, + storage_options=storage_options, + stacklevel=stacklevel + 1, ) store_entrypoint = StoreBackendEntrypoint() @@ -713,5 +850,4 @@ def open_dataset( return ds -if has_zarr: - BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint +BACKEND_ENTRYPOINTS["zarr"] = ZarrBackendEntrypoint diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 3c92c816e12..c031bffb2cd 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -178,8 +178,7 @@ def _get_day_of_month(other, day_option): if day_option == "start": return 1 elif day_option == "end": - days_in_month = _days_in_month(other) - return days_in_month + return _days_in_month(other) elif day_option is None: # Note: unlike `_shift_month`, _get_day_of_month does not # allow day_option = None @@ -291,10 +290,7 @@ def roll_qtrday(other, n, month, day_option, modby=3): def _validate_month(month, default_month): - if month is None: - result_month = default_month - else: - result_month = month + result_month = default_month if month is None else month if not isinstance(result_month, int): raise TypeError( "'self.month' must be an integer value between 1 " @@ -576,6 +572,26 @@ def __apply__(self, other): return other + self.as_timedelta() +class Millisecond(BaseCFTimeOffset): + _freq = "L" + + def as_timedelta(self): + return timedelta(milliseconds=self.n) + + def __apply__(self, other): + return other + self.as_timedelta() + + +class Microsecond(BaseCFTimeOffset): + _freq = "U" + + def as_timedelta(self): + return timedelta(microseconds=self.n) + + def __apply__(self, other): + return other + self.as_timedelta() + + _FREQUENCIES = { "A": YearEnd, "AS": YearBegin, @@ -590,6 +606,10 @@ def __apply__(self, other): "T": Minute, "min": Minute, "S": Second, + "L": Millisecond, + "ms": Millisecond, + "U": Microsecond, + "us": Microsecond, "AS-JAN": partial(YearBegin, month=1), "AS-FEB": partial(YearBegin, month=2), "AS-MAR": partial(YearBegin, month=3), @@ -663,11 +683,7 @@ def to_offset(freq): freq = freq_data["freq"] multiples = freq_data["multiple"] - if multiples is None: - multiples = 1 - else: - multiples = int(multiples) - + multiples = 1 if multiples is None else int(multiples) return _FREQUENCIES[freq](n=multiples) @@ -796,7 +812,7 @@ def cftime_range( periods : int, optional Number of periods to generate. freq : str or None, default: "D" - Frequency strings can have multiples, e.g. "5H". + Frequency strings can have multiples, e.g. "5H". normalize : bool, default: False Normalize start/end dates to midnight before generating date range. name : str, default: None @@ -813,7 +829,6 @@ def cftime_range( Notes ----- - This function is an analog of ``pandas.date_range`` for use in generating sequences of ``cftime.datetime`` objects. It supports most of the features of ``pandas.date_range`` (e.g. specifying how the index is @@ -825,7 +840,7 @@ def cftime_range( `ISO-8601 format `_. - It supports many, but not all, frequencies supported by ``pandas.date_range``. For example it does not currently support any of - the business-related, semi-monthly, or sub-second frequencies. + the business-related or semi-monthly frequencies. - Compound sub-monthly frequencies are not supported, e.g. '1H1min', as these can easily be written in terms of the finest common resolution, e.g. '61min'. @@ -856,6 +871,10 @@ def cftime_range( +--------+--------------------------+ | S | Second frequency | +--------+--------------------------+ + | L, ms | Millisecond frequency | + +--------+--------------------------+ + | U, us | Microsecond frequency | + +--------+--------------------------+ Any multiples of the following anchored offsets are also supported. @@ -911,7 +930,6 @@ def cftime_range( | Q(S)-DEC | Quarter frequency, anchored at the end (or beginning) of December | +----------+--------------------------------------------------------------------+ - Finally, the following calendar aliases are supported. +--------------------------------+---------------------------------------+ @@ -932,7 +950,6 @@ def cftime_range( Examples -------- - This function returns a ``CFTimeIndex``, populated with ``cftime.datetime`` objects associated with the specified calendar type, e.g. diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index e414740d420..783fe8d04d9 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -43,6 +43,7 @@ import warnings from datetime import timedelta from distutils.version import LooseVersion +from typing import Tuple, Type import numpy as np import pandas as pd @@ -59,6 +60,13 @@ REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END = 10 +OUT_OF_BOUNDS_TIMEDELTA_ERRORS: Tuple[Type[Exception], ...] +try: + OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (pd.errors.OutOfBoundsTimedelta, OverflowError) +except AttributeError: + OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (OverflowError,) + + def named(name, pattern): return "(?P<" + name + ">" + pattern + ")" @@ -247,7 +255,7 @@ def format_times( indent = first_row_offset if row == 0 else offset row_end = last_row_end if row == n_rows - 1 else intermediate_row_end times_for_row = index[row * n_per_row : (row + 1) * n_per_row] - representation = representation + format_row( + representation += format_row( times_for_row, indent=indent, separator=separator, row_end=row_end ) @@ -260,8 +268,9 @@ def format_attrs(index, separator=", "): "dtype": f"'{index.dtype}'", "length": f"{len(index)}", "calendar": f"'{index.calendar}'", + "freq": f"'{index.freq}'" if len(index) >= 3 else None, } - attrs["freq"] = f"'{index.freq}'" if len(index) >= 3 else None + attrs_str = [f"{k}={v}" for k, v in attrs.items()] attrs_str = f"{separator}".join(attrs_str) return attrs_str @@ -342,14 +351,13 @@ def __repr__(self): attrs_str = format_attrs(self) # oneliner only if smaller than display_width full_repr_str = f"{klass_name}([{datastr}], {attrs_str})" - if len(full_repr_str) <= display_width: - return full_repr_str - else: + if len(full_repr_str) > display_width: # if attrs_str too long, one per line if len(attrs_str) >= display_width - offset: attrs_str = attrs_str.replace(",", f",\n{' '*(offset-2)}") full_repr_str = f"{klass_name}([{datastr}],\n{' '*(offset-1)}{attrs_str})" - return full_repr_str + + return full_repr_str def _partial_date_slice(self, resolution, parsed): """Adapted from @@ -363,8 +371,6 @@ def _partial_date_slice(self, resolution, parsed): defining the index. For example: >>> from cftime import DatetimeNoLeap - >>> import pandas as pd - >>> import xarray as xr >>> da = xr.DataArray( ... [1, 2], ... coords=[[DatetimeNoLeap(2001, 1, 1), DatetimeNoLeap(2001, 2, 1)]], @@ -459,18 +465,23 @@ def get_loc(self, key, method=None, tolerance=None): else: return pd.Index.get_loc(self, key, method=method, tolerance=tolerance) - def _maybe_cast_slice_bound(self, label, side, kind): + def _maybe_cast_slice_bound(self, label, side, kind=None): """Adapted from - pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound""" - if isinstance(label, str): - parsed, resolution = _parse_iso8601_with_reso(self.date_type, label) - start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) - if self.is_monotonic_decreasing and len(self) > 1: - return end if side == "left" else start - return start if side == "left" else end - else: + pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound + + Note that we have never used the kind argument in CFTimeIndex and it is + deprecated as of pandas version 1.3.0. It exists only for compatibility + reasons. We can remove it when our minimum version of pandas is 1.3.0. + """ + if not isinstance(label, str): return label + parsed, resolution = _parse_iso8601_with_reso(self.date_type, label) + start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) + if self.is_monotonic_decreasing and len(self) > 1: + return end if side == "left" else start + return start if side == "left" else end + # TODO: Add ability to use integer range outside of iloc? # e.g. series[1:5]. def get_value(self, series, key): @@ -516,7 +527,7 @@ def shift(self, n, freq): ------- CFTimeIndex - See also + See Also -------- pandas.DatetimeIndex.shift @@ -562,7 +573,7 @@ def __sub__(self, other): elif _contains_cftime_datetimes(np.array(other)): try: return pd.TimedeltaIndex(np.array(self) - np.array(other)) - except OverflowError: + except OUT_OF_BOUNDS_TIMEDELTA_ERRORS: raise ValueError( "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." @@ -573,7 +584,7 @@ def __sub__(self, other): def __rsub__(self, other): try: return pd.TimedeltaIndex(other - np.array(self)) - except OverflowError: + except OUT_OF_BOUNDS_TIMEDELTA_ERRORS: raise ValueError( "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." @@ -611,7 +622,6 @@ def to_datetimeindex(self, unsafe=False): Examples -------- - >>> import xarray as xr >>> times = xr.cftime_range("2000", periods=2, calendar="gregorian") >>> times CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], diff --git a/xarray/coding/frequencies.py b/xarray/coding/frequencies.py index fa11d05923f..e9efef8eb7a 100644 --- a/xarray/coding/frequencies.py +++ b/xarray/coding/frequencies.py @@ -62,8 +62,8 @@ def infer_freq(index): Parameters ---------- index : CFTimeIndex, DataArray, DatetimeIndex, TimedeltaIndex, Series - If not passed a CFTimeIndex, this simply calls `pandas.infer_freq`. - If passed a Series or a DataArray will use the values of the series (NOT THE INDEX). + If not passed a CFTimeIndex, this simply calls `pandas.infer_freq`. + If passed a Series or a DataArray will use the values of the series (NOT THE INDEX). Returns ------- @@ -187,7 +187,7 @@ def _get_quartely_rule(self): if len(self.month_deltas) > 1: return None - if not self.month_deltas[0] % 3 == 0: + if self.month_deltas[0] % 3 != 0: return None return {"cs": "QS", "ce": "Q"}.get(month_anchor_check(self.index)) @@ -259,8 +259,7 @@ def month_anchor_check(dates): if calendar_end: cal = date.day == date.daysinmonth - if calendar_end: - calendar_end &= cal + calendar_end &= cal elif not calendar_start: break diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index e16e983fd8a..c217cb0c865 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -111,7 +111,7 @@ def encode(self, variable, name=None): if "char_dim_name" in encoding.keys(): char_dim_name = encoding.pop("char_dim_name") else: - char_dim_name = "string%s" % data.shape[-1] + char_dim_name = f"string{data.shape[-1]}" dims = dims + (char_dim_name,) return Variable(dims, data, attrs, encoding) @@ -140,8 +140,7 @@ def bytes_to_char(arr): chunks=arr.chunks + ((arr.dtype.itemsize,)), new_axis=[arr.ndim], ) - else: - return _numpy_bytes_to_char(arr) + return _numpy_bytes_to_char(arr) def _numpy_bytes_to_char(arr): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 3d877a169f5..f62a3961207 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -1,6 +1,6 @@ import re import warnings -from datetime import datetime +from datetime import datetime, timedelta from distutils.version import LooseVersion from functools import partial @@ -35,6 +35,26 @@ "D": int(1e9) * 60 * 60 * 24, } +_US_PER_TIME_DELTA = { + "microseconds": 1, + "milliseconds": 1_000, + "seconds": 1_000_000, + "minutes": 60 * 1_000_000, + "hours": 60 * 60 * 1_000_000, + "days": 24 * 60 * 60 * 1_000_000, +} + +_NETCDF_TIME_UNITS_CFTIME = [ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", +] + +_NETCDF_TIME_UNITS_NUMPY = _NETCDF_TIME_UNITS_CFTIME + ["nanoseconds"] + TIME_UNITS = frozenset( [ "days", @@ -48,10 +68,14 @@ ) +def _is_standard_calendar(calendar): + return calendar.lower() in _STANDARD_CALENDARS + + def _netcdf_to_numpy_timeunit(units): units = units.lower() if not units.endswith("s"): - units = "%ss" % units + units = f"{units}s" return { "nanoseconds": "ns", "microseconds": "us", @@ -80,6 +104,8 @@ def _ensure_padded_year(ref_date): # No four-digit strings, assume the first digits are the year and pad # appropriately matches_start_digits = re.match(r"(\d+)(.*)", ref_date) + if not matches_start_digits: + raise ValueError(f"invalid reference date for time units: {ref_date}") ref_year, everything_else = [s for s in matches_start_digits.groups()] ref_date_padded = "{:04d}{}".format(int(ref_year), everything_else) @@ -123,7 +149,7 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): result = decode_cf_datetime(example_value, units, calendar, use_cftime) except Exception: calendar_msg = ( - "the default calendar" if calendar is None else "calendar %r" % calendar + "the default calendar" if calendar is None else f"calendar {calendar!r}" ) msg = ( f"unable to decode time units {units!r} with {calendar_msg!r}. Try " @@ -146,7 +172,7 @@ def _decode_datetime_with_cftime(num_dates, units, calendar): def _decode_datetime_with_pandas(flat_num_dates, units, calendar): - if calendar not in _STANDARD_CALENDARS: + if not _is_standard_calendar(calendar): raise OutOfBoundsDatetime( "Cannot decode times from a non-standard calendar, {!r}, using " "pandas.".format(calendar) @@ -161,6 +187,11 @@ def _decode_datetime_with_pandas(flat_num_dates, units, calendar): # strings, in which case we fall back to using cftime raise OutOfBoundsDatetime + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning) + pd.to_timedelta(flat_num_dates.min(), delta) + ref_date + pd.to_timedelta(flat_num_dates.max(), delta) + ref_date + # To avoid integer overflow when converting to nanosecond units for integer # dtypes smaller than np.int64 cast all integer-dtype arrays to np.int64 # (GH 2002). @@ -191,7 +222,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): Note that time unit in `units` must not be smaller than microseconds and not larger than days. - See also + See Also -------- cftime.num2date """ @@ -212,7 +243,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): dates[np.nanargmin(num_dates)].year < 1678 or dates[np.nanargmax(num_dates)].year >= 2262 ): - if calendar in _STANDARD_CALENDARS: + if _is_standard_calendar(calendar): warnings.warn( "Unable to decode time axis into full " "numpy.datetime64 objects, continuing using " @@ -222,12 +253,10 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): stacklevel=3, ) else: - if calendar in _STANDARD_CALENDARS: + if _is_standard_calendar(calendar): dates = cftime_to_nptime(dates) elif use_cftime: - dates = _decode_datetime_with_cftime( - flat_num_dates.astype(float), units, calendar - ) + dates = _decode_datetime_with_cftime(flat_num_dates, units, calendar) else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) @@ -262,25 +291,33 @@ def decode_cf_timedelta(num_timedeltas, units): return result.reshape(num_timedeltas.shape) +def _unit_timedelta_cftime(units): + return timedelta(microseconds=_US_PER_TIME_DELTA[units]) + + +def _unit_timedelta_numpy(units): + numpy_units = _netcdf_to_numpy_timeunit(units) + return np.timedelta64(_NS_PER_TIME_DELTA[numpy_units], "ns") + + def _infer_time_units_from_diff(unique_timedeltas): - # Note that the modulus operator was only implemented for np.timedelta64 - # arrays as of NumPy version 1.16.0. Once our minimum version of NumPy - # supported is greater than or equal to this we will no longer need to cast - # unique_timedeltas to a TimedeltaIndex. In the meantime, however, the - # modulus operator works for TimedeltaIndex objects. - unique_deltas_as_index = pd.TimedeltaIndex(unique_timedeltas) - for time_unit in [ - "days", - "hours", - "minutes", - "seconds", - "milliseconds", - "microseconds", - "nanoseconds", - ]: - delta_ns = _NS_PER_TIME_DELTA[_netcdf_to_numpy_timeunit(time_unit)] - unit_delta = np.timedelta64(delta_ns, "ns") - if np.all(unique_deltas_as_index % unit_delta == np.timedelta64(0, "ns")): + if unique_timedeltas.dtype == np.dtype("O"): + time_units = _NETCDF_TIME_UNITS_CFTIME + unit_timedelta = _unit_timedelta_cftime + zero_timedelta = timedelta(microseconds=0) + timedeltas = unique_timedeltas + else: + time_units = _NETCDF_TIME_UNITS_NUMPY + unit_timedelta = _unit_timedelta_numpy + zero_timedelta = np.timedelta64(0, "ns") + # Note that the modulus operator was only implemented for np.timedelta64 + # arrays as of NumPy version 1.16.0. Once our minimum version of NumPy + # supported is greater than or equal to this we will no longer need to cast + # unique_timedeltas to a TimedeltaIndex. In the meantime, however, the + # modulus operator works for TimedeltaIndex objects. + timedeltas = pd.TimedeltaIndex(unique_timedeltas) + for time_unit in time_units: + if np.all(timedeltas % unit_timedelta(time_unit) == zero_timedelta): return time_unit return "seconds" @@ -309,10 +346,6 @@ def infer_datetime_units(dates): reference_date = dates[0] if len(dates) > 0 else "1970-01-01" reference_date = format_cftime_datetime(reference_date) unique_timedeltas = np.unique(np.diff(dates)) - if unique_timedeltas.dtype == np.dtype("O"): - # Convert to np.timedelta64 objects using pandas to work around a - # NumPy casting bug: https://github.com/numpy/numpy/issues/11096 - unique_timedeltas = to_timedelta_unboxed(unique_timedeltas) units = _infer_time_units_from_diff(unique_timedeltas) return f"{units} since {reference_date}" @@ -339,8 +372,7 @@ def infer_timedelta_units(deltas): """ deltas = to_timedelta_unboxed(np.asarray(deltas).ravel()) unique_timedeltas = np.unique(deltas[pd.notnull(deltas)]) - units = _infer_time_units_from_diff(unique_timedeltas) - return units + return _infer_time_units_from_diff(unique_timedeltas) def cftime_to_nptime(times): @@ -391,7 +423,7 @@ def _encode_datetime_with_cftime(dates, units, calendar): def encode_datetime(d): return np.nan if d is None else cftime.date2num(d, units, calendar) - return np.vectorize(encode_datetime)(dates) + return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape) def cast_to_int_if_safe(num): @@ -407,7 +439,7 @@ def encode_cf_datetime(dates, units=None, calendar=None): Unlike `date2num`, this function can handle datetime64 arrays. - See also + See Also -------- cftime.date2num """ @@ -423,7 +455,7 @@ def encode_cf_datetime(dates, units=None, calendar=None): delta, ref_date = _unpack_netcdf_time_units(units) try: - if calendar not in _STANDARD_CALENDARS or dates.dtype.kind == "O": + if not _is_standard_calendar(calendar) or dates.dtype.kind == "O": # parse with cftime instead raise OutOfBoundsDatetime assert dates.dtype == "datetime64[ns]" diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index b035ff82086..1ebaab1be02 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -77,7 +77,6 @@ def __repr__(self): def lazy_elemwise_func(array, func, dtype): """Lazily apply an element-wise function to an array. - Parameters ---------- array : any valid value of Variable._data @@ -255,10 +254,10 @@ def encode(self, variable, name=None): if "scale_factor" in encoding or "add_offset" in encoding: dtype = _choose_float_dtype(data.dtype, "add_offset" in encoding) data = data.astype(dtype=dtype, copy=True) - if "add_offset" in encoding: - data -= pop_to(encoding, attrs, "add_offset", name=name) - if "scale_factor" in encoding: - data /= pop_to(encoding, attrs, "scale_factor", name=name) + if "add_offset" in encoding: + data -= pop_to(encoding, attrs, "add_offset", name=name) + if "scale_factor" in encoding: + data /= pop_to(encoding, attrs, "scale_factor", name=name) return Variable(dims, data, attrs, encoding) @@ -294,7 +293,7 @@ def encode(self, variable, name=None): # integer data should be treated as unsigned" if encoding.get("_Unsigned", "false") == "true": pop_to(encoding, attrs, "_Unsigned") - signed_dtype = np.dtype("i%s" % data.dtype.itemsize) + signed_dtype = np.dtype(f"i{data.dtype.itemsize}") if "_FillValue" in attrs: new_fill = signed_dtype.type(attrs["_FillValue"]) attrs["_FillValue"] = new_fill @@ -310,16 +309,24 @@ def decode(self, variable, name=None): if data.dtype.kind == "i": if unsigned == "true": - unsigned_dtype = np.dtype("u%s" % data.dtype.itemsize) + unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}") transform = partial(np.asarray, dtype=unsigned_dtype) data = lazy_elemwise_func(data, transform, unsigned_dtype) if "_FillValue" in attrs: new_fill = unsigned_dtype.type(attrs["_FillValue"]) attrs["_FillValue"] = new_fill + elif data.dtype.kind == "u": + if unsigned == "false": + signed_dtype = np.dtype(f"i{data.dtype.itemsize}") + transform = partial(np.asarray, dtype=signed_dtype) + data = lazy_elemwise_func(data, transform, signed_dtype) + if "_FillValue" in attrs: + new_fill = signed_dtype.type(attrs["_FillValue"]) + attrs["_FillValue"] = new_fill else: warnings.warn( - "variable %r has _Unsigned attribute but is not " - "of integer type. Ignoring attribute." % name, + f"variable {name!r} has _Unsigned attribute but is not " + "of integer type. Ignoring attribute.", SerializationWarning, stacklevel=3, ) diff --git a/xarray/conventions.py b/xarray/conventions.py index e33ae53b31d..c3a05e42f82 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -11,6 +11,23 @@ from .core.pycompat import is_duck_dask_array from .core.variable import IndexVariable, Variable, as_variable +CF_RELATED_DATA = ( + "bounds", + "grid_mapping", + "climatology", + "geometry", + "node_coordinates", + "node_count", + "part_node_count", + "interior_ring", + "cell_measures", + "formula_terms", +) +CF_RELATED_DATA_NEEDS_PARSING = ( + "cell_measures", + "formula_terms", +) + class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): """Decode arrays on the fly from non-native to native endianness @@ -93,9 +110,9 @@ def maybe_encode_nonstring_dtype(var, name=None): and "missing_value" not in var.attrs ): warnings.warn( - "saving variable %s with floating " + f"saving variable {name} with floating " "point data as an integer dtype without " - "any _FillValue to use for NaNs" % name, + "any _FillValue to use for NaNs", SerializationWarning, stacklevel=10, ) @@ -256,6 +273,9 @@ def encode_cf_variable(var, needs_copy=True, name=None): var = maybe_default_fill_value(var) var = maybe_encode_bools(var) var = ensure_dtype_not_object(var, name=name) + + for attr_name in CF_RELATED_DATA: + pop_to(var.encoding, var.attrs, attr_name) return var @@ -354,7 +374,7 @@ def decode_cf_variable( data = BoolTypeArray(data) if not is_duck_dask_array(data): - data = indexing.LazilyOuterIndexedArray(data) + data = indexing.LazilyIndexedArray(data) return Variable(dimensions, data, attributes, encoding=encoding) @@ -499,7 +519,7 @@ def stackable(dim): use_cftime=use_cftime, decode_timedelta=decode_timedelta, ) - if decode_coords: + if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: coord_str = var_attrs["coordinates"] @@ -509,6 +529,38 @@ def stackable(dim): del var_attrs["coordinates"] coord_names.update(var_coord_names) + if decode_coords == "all": + for attr_name in CF_RELATED_DATA: + if attr_name in var_attrs: + attr_val = var_attrs[attr_name] + if attr_name not in CF_RELATED_DATA_NEEDS_PARSING: + var_names = attr_val.split() + else: + roles_and_names = [ + role_or_name + for part in attr_val.split(":") + for role_or_name in part.split() + ] + if len(roles_and_names) % 2 == 1: + warnings.warn( + f"Attribute {attr_name:s} malformed", stacklevel=5 + ) + var_names = roles_and_names[1::2] + if all(var_name in variables for var_name in var_names): + new_vars[k].encoding[attr_name] = attr_val + coord_names.update(var_names) + else: + referenced_vars_not_in_variables = [ + proj_name + for proj_name in var_names + if proj_name not in variables + ] + warnings.warn( + f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}", + stacklevel=5, + ) + del var_attrs[attr_name] + if decode_coords and "coordinates" in attributes: attributes = dict(attributes) coord_names.update(attributes.pop("coordinates").split()) @@ -542,9 +594,14 @@ def decode_cf( decode_times : bool, optional Decode cf times (e.g., integers since "hours since 2000-01-01") to np.datetime64. - decode_coords : bool, optional - Use the 'coordinates' attribute on variable (or the dataset itself) to - identify coordinates. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. drop_variables : str or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -624,7 +681,7 @@ def cf_decoder( concat_characters : bool Should character arrays be concatenated to strings, for example: ["h", "e", "l", "l", "o"] -> "hello" - mask_and_scale: bool + mask_and_scale : bool Lazily scale (using scale_factor and add_offset) and mask (using _FillValue). decode_times : bool @@ -637,7 +694,7 @@ def cf_decoder( decoded_attributes : dict A dictionary mapping from attribute name to values. - See also + See Also -------- decode_cf_variable """ @@ -664,6 +721,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): global_coordinates = non_dim_coord_names.copy() variable_coordinates = defaultdict(set) + not_technically_coordinates = set() for coord_name in non_dim_coord_names: target_dims = variables[coord_name].dims for k, v in variables.items(): @@ -674,6 +732,13 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): ): variable_coordinates[k].add(coord_name) + if any( + attr_name in v.encoding and coord_name in v.encoding.get(attr_name) + for attr_name in CF_RELATED_DATA + ): + not_technically_coordinates.add(coord_name) + global_coordinates.discard(coord_name) + variables = {k: v.copy(deep=False) for k, v in variables.items()} # keep track of variable names written to file under the "coordinates" attributes @@ -686,12 +751,30 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): f"'coordinates' found in both attrs and encoding for variable {name!r}." ) + # if coordinates set to None, don't write coordinates attribute + if ( + "coordinates" in attrs + and attrs.get("coordinates") is None + or "coordinates" in encoding + and encoding.get("coordinates") is None + ): + # make sure "coordinates" is removed from attrs/encoding + attrs.pop("coordinates", None) + encoding.pop("coordinates", None) + continue + # this will copy coordinates from encoding to attrs if "coordinates" in attrs # after the next line, "coordinates" is never in encoding # we get support for attrs["coordinates"] for free. coords_str = pop_to(encoding, attrs, "coordinates") if not coords_str and variable_coordinates[name]: - attrs["coordinates"] = " ".join(map(str, variable_coordinates[name])) + coordinates_text = " ".join( + str(coord_name) + for coord_name in variable_coordinates[name] + if coord_name not in not_technically_coordinates + ) + if coordinates_text: + attrs["coordinates"] = coordinates_text if "coordinates" in attrs: written_coords.update(attrs["coordinates"].split()) @@ -747,7 +830,6 @@ def cf_encoder(variables, attributes): This includes masking, scaling, character array handling, and CF-time encoding. - Parameters ---------- variables : dict @@ -762,7 +844,7 @@ def cf_encoder(variables, attributes): encoded_attributes : dict A dictionary mapping from attribute name to value - See also + See Also -------- decode_cf_variable, encode_cf_variable """ diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py new file mode 100644 index 00000000000..d1e68a6fc0d --- /dev/null +++ b/xarray/core/_typed_ops.py @@ -0,0 +1,800 @@ +"""Mixin classes with arithmetic operators.""" +# This file was generated using xarray.util.generate_ops. Do not edit manually. + +import operator + +from . import nputils, ops + + +class DatasetOpsMixin: + __slots__ = () + + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError + + def __add__(self, other): + return self._binary_op(other, operator.add) + + def __sub__(self, other): + return self._binary_op(other, operator.sub) + + def __mul__(self, other): + return self._binary_op(other, operator.mul) + + def __pow__(self, other): + return self._binary_op(other, operator.pow) + + def __truediv__(self, other): + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other): + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other): + return self._binary_op(other, operator.mod) + + def __and__(self, other): + return self._binary_op(other, operator.and_) + + def __xor__(self, other): + return self._binary_op(other, operator.xor) + + def __or__(self, other): + return self._binary_op(other, operator.or_) + + def __lt__(self, other): + return self._binary_op(other, operator.lt) + + def __le__(self, other): + return self._binary_op(other, operator.le) + + def __gt__(self, other): + return self._binary_op(other, operator.gt) + + def __ge__(self, other): + return self._binary_op(other, operator.ge) + + def __eq__(self, other): + return self._binary_op(other, nputils.array_eq) + + def __ne__(self, other): + return self._binary_op(other, nputils.array_ne) + + def __radd__(self, other): + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other): + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other): + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other): + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other): + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other): + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other): + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other): + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other): + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other): + return self._binary_op(other, operator.or_, reflexive=True) + + def _inplace_binary_op(self, other, f): + raise NotImplementedError + + def __iadd__(self, other): + return self._inplace_binary_op(other, operator.iadd) + + def __isub__(self, other): + return self._inplace_binary_op(other, operator.isub) + + def __imul__(self, other): + return self._inplace_binary_op(other, operator.imul) + + def __ipow__(self, other): + return self._inplace_binary_op(other, operator.ipow) + + def __itruediv__(self, other): + return self._inplace_binary_op(other, operator.itruediv) + + def __ifloordiv__(self, other): + return self._inplace_binary_op(other, operator.ifloordiv) + + def __imod__(self, other): + return self._inplace_binary_op(other, operator.imod) + + def __iand__(self, other): + return self._inplace_binary_op(other, operator.iand) + + def __ixor__(self, other): + return self._inplace_binary_op(other, operator.ixor) + + def __ior__(self, other): + return self._inplace_binary_op(other, operator.ior) + + def _unary_op(self, f, *args, **kwargs): + raise NotImplementedError + + def __neg__(self): + return self._unary_op(operator.neg) + + def __pos__(self): + return self._unary_op(operator.pos) + + def __abs__(self): + return self._unary_op(operator.abs) + + def __invert__(self): + return self._unary_op(operator.invert) + + def round(self, *args, **kwargs): + return self._unary_op(ops.round_, *args, **kwargs) + + def argsort(self, *args, **kwargs): + return self._unary_op(ops.argsort, *args, **kwargs) + + def conj(self, *args, **kwargs): + return self._unary_op(ops.conj, *args, **kwargs) + + def conjugate(self, *args, **kwargs): + return self._unary_op(ops.conjugate, *args, **kwargs) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + __iadd__.__doc__ = operator.iadd.__doc__ + __isub__.__doc__ = operator.isub.__doc__ + __imul__.__doc__ = operator.imul.__doc__ + __ipow__.__doc__ = operator.ipow.__doc__ + __itruediv__.__doc__ = operator.itruediv.__doc__ + __ifloordiv__.__doc__ = operator.ifloordiv.__doc__ + __imod__.__doc__ = operator.imod.__doc__ + __iand__.__doc__ = operator.iand.__doc__ + __ixor__.__doc__ = operator.ixor.__doc__ + __ior__.__doc__ = operator.ior.__doc__ + __neg__.__doc__ = operator.neg.__doc__ + __pos__.__doc__ = operator.pos.__doc__ + __abs__.__doc__ = operator.abs.__doc__ + __invert__.__doc__ = operator.invert.__doc__ + round.__doc__ = ops.round_.__doc__ + argsort.__doc__ = ops.argsort.__doc__ + conj.__doc__ = ops.conj.__doc__ + conjugate.__doc__ = ops.conjugate.__doc__ + + +class DataArrayOpsMixin: + __slots__ = () + + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError + + def __add__(self, other): + return self._binary_op(other, operator.add) + + def __sub__(self, other): + return self._binary_op(other, operator.sub) + + def __mul__(self, other): + return self._binary_op(other, operator.mul) + + def __pow__(self, other): + return self._binary_op(other, operator.pow) + + def __truediv__(self, other): + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other): + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other): + return self._binary_op(other, operator.mod) + + def __and__(self, other): + return self._binary_op(other, operator.and_) + + def __xor__(self, other): + return self._binary_op(other, operator.xor) + + def __or__(self, other): + return self._binary_op(other, operator.or_) + + def __lt__(self, other): + return self._binary_op(other, operator.lt) + + def __le__(self, other): + return self._binary_op(other, operator.le) + + def __gt__(self, other): + return self._binary_op(other, operator.gt) + + def __ge__(self, other): + return self._binary_op(other, operator.ge) + + def __eq__(self, other): + return self._binary_op(other, nputils.array_eq) + + def __ne__(self, other): + return self._binary_op(other, nputils.array_ne) + + def __radd__(self, other): + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other): + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other): + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other): + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other): + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other): + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other): + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other): + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other): + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other): + return self._binary_op(other, operator.or_, reflexive=True) + + def _inplace_binary_op(self, other, f): + raise NotImplementedError + + def __iadd__(self, other): + return self._inplace_binary_op(other, operator.iadd) + + def __isub__(self, other): + return self._inplace_binary_op(other, operator.isub) + + def __imul__(self, other): + return self._inplace_binary_op(other, operator.imul) + + def __ipow__(self, other): + return self._inplace_binary_op(other, operator.ipow) + + def __itruediv__(self, other): + return self._inplace_binary_op(other, operator.itruediv) + + def __ifloordiv__(self, other): + return self._inplace_binary_op(other, operator.ifloordiv) + + def __imod__(self, other): + return self._inplace_binary_op(other, operator.imod) + + def __iand__(self, other): + return self._inplace_binary_op(other, operator.iand) + + def __ixor__(self, other): + return self._inplace_binary_op(other, operator.ixor) + + def __ior__(self, other): + return self._inplace_binary_op(other, operator.ior) + + def _unary_op(self, f, *args, **kwargs): + raise NotImplementedError + + def __neg__(self): + return self._unary_op(operator.neg) + + def __pos__(self): + return self._unary_op(operator.pos) + + def __abs__(self): + return self._unary_op(operator.abs) + + def __invert__(self): + return self._unary_op(operator.invert) + + def round(self, *args, **kwargs): + return self._unary_op(ops.round_, *args, **kwargs) + + def argsort(self, *args, **kwargs): + return self._unary_op(ops.argsort, *args, **kwargs) + + def conj(self, *args, **kwargs): + return self._unary_op(ops.conj, *args, **kwargs) + + def conjugate(self, *args, **kwargs): + return self._unary_op(ops.conjugate, *args, **kwargs) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + __iadd__.__doc__ = operator.iadd.__doc__ + __isub__.__doc__ = operator.isub.__doc__ + __imul__.__doc__ = operator.imul.__doc__ + __ipow__.__doc__ = operator.ipow.__doc__ + __itruediv__.__doc__ = operator.itruediv.__doc__ + __ifloordiv__.__doc__ = operator.ifloordiv.__doc__ + __imod__.__doc__ = operator.imod.__doc__ + __iand__.__doc__ = operator.iand.__doc__ + __ixor__.__doc__ = operator.ixor.__doc__ + __ior__.__doc__ = operator.ior.__doc__ + __neg__.__doc__ = operator.neg.__doc__ + __pos__.__doc__ = operator.pos.__doc__ + __abs__.__doc__ = operator.abs.__doc__ + __invert__.__doc__ = operator.invert.__doc__ + round.__doc__ = ops.round_.__doc__ + argsort.__doc__ = ops.argsort.__doc__ + conj.__doc__ = ops.conj.__doc__ + conjugate.__doc__ = ops.conjugate.__doc__ + + +class VariableOpsMixin: + __slots__ = () + + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError + + def __add__(self, other): + return self._binary_op(other, operator.add) + + def __sub__(self, other): + return self._binary_op(other, operator.sub) + + def __mul__(self, other): + return self._binary_op(other, operator.mul) + + def __pow__(self, other): + return self._binary_op(other, operator.pow) + + def __truediv__(self, other): + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other): + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other): + return self._binary_op(other, operator.mod) + + def __and__(self, other): + return self._binary_op(other, operator.and_) + + def __xor__(self, other): + return self._binary_op(other, operator.xor) + + def __or__(self, other): + return self._binary_op(other, operator.or_) + + def __lt__(self, other): + return self._binary_op(other, operator.lt) + + def __le__(self, other): + return self._binary_op(other, operator.le) + + def __gt__(self, other): + return self._binary_op(other, operator.gt) + + def __ge__(self, other): + return self._binary_op(other, operator.ge) + + def __eq__(self, other): + return self._binary_op(other, nputils.array_eq) + + def __ne__(self, other): + return self._binary_op(other, nputils.array_ne) + + def __radd__(self, other): + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other): + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other): + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other): + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other): + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other): + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other): + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other): + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other): + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other): + return self._binary_op(other, operator.or_, reflexive=True) + + def _inplace_binary_op(self, other, f): + raise NotImplementedError + + def __iadd__(self, other): + return self._inplace_binary_op(other, operator.iadd) + + def __isub__(self, other): + return self._inplace_binary_op(other, operator.isub) + + def __imul__(self, other): + return self._inplace_binary_op(other, operator.imul) + + def __ipow__(self, other): + return self._inplace_binary_op(other, operator.ipow) + + def __itruediv__(self, other): + return self._inplace_binary_op(other, operator.itruediv) + + def __ifloordiv__(self, other): + return self._inplace_binary_op(other, operator.ifloordiv) + + def __imod__(self, other): + return self._inplace_binary_op(other, operator.imod) + + def __iand__(self, other): + return self._inplace_binary_op(other, operator.iand) + + def __ixor__(self, other): + return self._inplace_binary_op(other, operator.ixor) + + def __ior__(self, other): + return self._inplace_binary_op(other, operator.ior) + + def _unary_op(self, f, *args, **kwargs): + raise NotImplementedError + + def __neg__(self): + return self._unary_op(operator.neg) + + def __pos__(self): + return self._unary_op(operator.pos) + + def __abs__(self): + return self._unary_op(operator.abs) + + def __invert__(self): + return self._unary_op(operator.invert) + + def round(self, *args, **kwargs): + return self._unary_op(ops.round_, *args, **kwargs) + + def argsort(self, *args, **kwargs): + return self._unary_op(ops.argsort, *args, **kwargs) + + def conj(self, *args, **kwargs): + return self._unary_op(ops.conj, *args, **kwargs) + + def conjugate(self, *args, **kwargs): + return self._unary_op(ops.conjugate, *args, **kwargs) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + __iadd__.__doc__ = operator.iadd.__doc__ + __isub__.__doc__ = operator.isub.__doc__ + __imul__.__doc__ = operator.imul.__doc__ + __ipow__.__doc__ = operator.ipow.__doc__ + __itruediv__.__doc__ = operator.itruediv.__doc__ + __ifloordiv__.__doc__ = operator.ifloordiv.__doc__ + __imod__.__doc__ = operator.imod.__doc__ + __iand__.__doc__ = operator.iand.__doc__ + __ixor__.__doc__ = operator.ixor.__doc__ + __ior__.__doc__ = operator.ior.__doc__ + __neg__.__doc__ = operator.neg.__doc__ + __pos__.__doc__ = operator.pos.__doc__ + __abs__.__doc__ = operator.abs.__doc__ + __invert__.__doc__ = operator.invert.__doc__ + round.__doc__ = ops.round_.__doc__ + argsort.__doc__ = ops.argsort.__doc__ + conj.__doc__ = ops.conj.__doc__ + conjugate.__doc__ = ops.conjugate.__doc__ + + +class DatasetGroupByOpsMixin: + __slots__ = () + + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError + + def __add__(self, other): + return self._binary_op(other, operator.add) + + def __sub__(self, other): + return self._binary_op(other, operator.sub) + + def __mul__(self, other): + return self._binary_op(other, operator.mul) + + def __pow__(self, other): + return self._binary_op(other, operator.pow) + + def __truediv__(self, other): + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other): + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other): + return self._binary_op(other, operator.mod) + + def __and__(self, other): + return self._binary_op(other, operator.and_) + + def __xor__(self, other): + return self._binary_op(other, operator.xor) + + def __or__(self, other): + return self._binary_op(other, operator.or_) + + def __lt__(self, other): + return self._binary_op(other, operator.lt) + + def __le__(self, other): + return self._binary_op(other, operator.le) + + def __gt__(self, other): + return self._binary_op(other, operator.gt) + + def __ge__(self, other): + return self._binary_op(other, operator.ge) + + def __eq__(self, other): + return self._binary_op(other, nputils.array_eq) + + def __ne__(self, other): + return self._binary_op(other, nputils.array_ne) + + def __radd__(self, other): + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other): + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other): + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other): + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other): + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other): + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other): + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other): + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other): + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other): + return self._binary_op(other, operator.or_, reflexive=True) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ + + +class DataArrayGroupByOpsMixin: + __slots__ = () + + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError + + def __add__(self, other): + return self._binary_op(other, operator.add) + + def __sub__(self, other): + return self._binary_op(other, operator.sub) + + def __mul__(self, other): + return self._binary_op(other, operator.mul) + + def __pow__(self, other): + return self._binary_op(other, operator.pow) + + def __truediv__(self, other): + return self._binary_op(other, operator.truediv) + + def __floordiv__(self, other): + return self._binary_op(other, operator.floordiv) + + def __mod__(self, other): + return self._binary_op(other, operator.mod) + + def __and__(self, other): + return self._binary_op(other, operator.and_) + + def __xor__(self, other): + return self._binary_op(other, operator.xor) + + def __or__(self, other): + return self._binary_op(other, operator.or_) + + def __lt__(self, other): + return self._binary_op(other, operator.lt) + + def __le__(self, other): + return self._binary_op(other, operator.le) + + def __gt__(self, other): + return self._binary_op(other, operator.gt) + + def __ge__(self, other): + return self._binary_op(other, operator.ge) + + def __eq__(self, other): + return self._binary_op(other, nputils.array_eq) + + def __ne__(self, other): + return self._binary_op(other, nputils.array_ne) + + def __radd__(self, other): + return self._binary_op(other, operator.add, reflexive=True) + + def __rsub__(self, other): + return self._binary_op(other, operator.sub, reflexive=True) + + def __rmul__(self, other): + return self._binary_op(other, operator.mul, reflexive=True) + + def __rpow__(self, other): + return self._binary_op(other, operator.pow, reflexive=True) + + def __rtruediv__(self, other): + return self._binary_op(other, operator.truediv, reflexive=True) + + def __rfloordiv__(self, other): + return self._binary_op(other, operator.floordiv, reflexive=True) + + def __rmod__(self, other): + return self._binary_op(other, operator.mod, reflexive=True) + + def __rand__(self, other): + return self._binary_op(other, operator.and_, reflexive=True) + + def __rxor__(self, other): + return self._binary_op(other, operator.xor, reflexive=True) + + def __ror__(self, other): + return self._binary_op(other, operator.or_, reflexive=True) + + __add__.__doc__ = operator.add.__doc__ + __sub__.__doc__ = operator.sub.__doc__ + __mul__.__doc__ = operator.mul.__doc__ + __pow__.__doc__ = operator.pow.__doc__ + __truediv__.__doc__ = operator.truediv.__doc__ + __floordiv__.__doc__ = operator.floordiv.__doc__ + __mod__.__doc__ = operator.mod.__doc__ + __and__.__doc__ = operator.and_.__doc__ + __xor__.__doc__ = operator.xor.__doc__ + __or__.__doc__ = operator.or_.__doc__ + __lt__.__doc__ = operator.lt.__doc__ + __le__.__doc__ = operator.le.__doc__ + __gt__.__doc__ = operator.gt.__doc__ + __ge__.__doc__ = operator.ge.__doc__ + __eq__.__doc__ = nputils.array_eq.__doc__ + __ne__.__doc__ = nputils.array_ne.__doc__ + __radd__.__doc__ = operator.add.__doc__ + __rsub__.__doc__ = operator.sub.__doc__ + __rmul__.__doc__ = operator.mul.__doc__ + __rpow__.__doc__ = operator.pow.__doc__ + __rtruediv__.__doc__ = operator.truediv.__doc__ + __rfloordiv__.__doc__ = operator.floordiv.__doc__ + __rmod__.__doc__ = operator.mod.__doc__ + __rand__.__doc__ = operator.and_.__doc__ + __rxor__.__doc__ = operator.xor.__doc__ + __ror__.__doc__ = operator.or_.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi new file mode 100644 index 00000000000..4a6c2dc7b4e --- /dev/null +++ b/xarray/core/_typed_ops.pyi @@ -0,0 +1,728 @@ +"""Stub file for mixin classes with arithmetic operators.""" +# This file was generated using xarray.util.generate_ops. Do not edit manually. + +from typing import NoReturn, TypeVar, Union, overload + +import numpy as np + +from .dataarray import DataArray +from .dataset import Dataset +from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy +from .npcompat import ArrayLike +from .variable import Variable + +try: + from dask.array import Array as DaskArray +except ImportError: + DaskArray = np.ndarray + +# DatasetOpsMixin etc. are parent classes of Dataset etc. +T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") +T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") +T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") + +ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] +DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] +DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] +VarCompatible = Union[Variable, ScalarOrArray] +GroupByIncompatible = Union[Variable, GroupBy] + +class DatasetOpsMixin: + __slots__ = () + def _binary_op(self, other, f, reflexive=...): ... + def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] + def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] + def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... + def _inplace_binary_op(self, other, f): ... + def _unary_op(self, f, *args, **kwargs): ... + def __neg__(self: T_Dataset) -> T_Dataset: ... + def __pos__(self: T_Dataset) -> T_Dataset: ... + def __abs__(self: T_Dataset) -> T_Dataset: ... + def __invert__(self: T_Dataset) -> T_Dataset: ... + def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... + def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... + def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... + def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... + +class DataArrayOpsMixin: + __slots__ = () + def _binary_op(self, other, f, reflexive=...): ... + @overload + def __add__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __sub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __mul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __pow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __truediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __mod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __and__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __xor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __or__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __lt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __le__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __gt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __ge__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload # type: ignore[override] + def __eq__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload # type: ignore[override] + def __ne__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __radd__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rsub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rmul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rpow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rmod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rand__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __rxor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + @overload + def __ror__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... + def _inplace_binary_op(self, other, f): ... + def _unary_op(self, f, *args, **kwargs): ... + def __neg__(self: T_DataArray) -> T_DataArray: ... + def __pos__(self: T_DataArray) -> T_DataArray: ... + def __abs__(self: T_DataArray) -> T_DataArray: ... + def __invert__(self: T_DataArray) -> T_DataArray: ... + def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... + def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... + def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... + def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... + +class VariableOpsMixin: + __slots__ = () + def _binary_op(self, other, f, reflexive=...): ... + @overload + def __add__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __add__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __sub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __sub__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __mul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mul__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __pow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __pow__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __truediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __mod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mod__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __and__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __and__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __xor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __xor__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __or__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __or__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __lt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lt__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __le__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __le__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __gt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __gt__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __ge__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ge__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload # type: ignore[override] + def __eq__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __eq__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload # type: ignore[override] + def __ne__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ne__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __radd__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __radd__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rsub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rsub__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rmul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmul__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rpow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rpow__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rmod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmod__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rand__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rand__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __rxor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rxor__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + @overload + def __ror__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ror__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... + def _inplace_binary_op(self, other, f): ... + def _unary_op(self, f, *args, **kwargs): ... + def __neg__(self: T_Variable) -> T_Variable: ... + def __pos__(self: T_Variable) -> T_Variable: ... + def __abs__(self: T_Variable) -> T_Variable: ... + def __invert__(self: T_Variable) -> T_Variable: ... + def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... + def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... + def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... + def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... + +class DatasetGroupByOpsMixin: + __slots__ = () + def _binary_op(self, other, f, reflexive=...): ... + @overload + def __add__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __add__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __add__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __sub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __sub__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __mul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mul__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __pow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __pow__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __truediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __truediv__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __floordiv__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __mod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mod__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __and__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __and__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __and__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __xor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __xor__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __or__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __or__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __or__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __lt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lt__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __le__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __le__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __le__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __gt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __gt__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __ge__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ge__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload # type: ignore[override] + def __eq__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __eq__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload # type: ignore[override] + def __ne__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ne__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __radd__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __radd__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rsub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rsub__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rmul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmul__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rpow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rpow__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rtruediv__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rmod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmod__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rand__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rand__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rxor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rxor__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __ror__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ror__(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... + +class DataArrayGroupByOpsMixin: + __slots__ = () + def _binary_op(self, other, f, reflexive=...): ... + @overload + def __add__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __add__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __add__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __sub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __sub__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __mul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mul__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __pow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __pow__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __truediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __mod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __mod__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __and__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __and__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __and__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __xor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __xor__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __or__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __or__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __or__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __lt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __lt__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __le__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __le__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __le__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __gt__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __gt__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __ge__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ge__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload # type: ignore[override] + def __eq__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __eq__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload # type: ignore[override] + def __ne__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ne__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __radd__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __radd__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rsub__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rsub__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rmul__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmul__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rpow__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rpow__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rmod__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rmod__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rand__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rand__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __rxor__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __rxor__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... + @overload + def __ror__(self, other: T_Dataset) -> T_Dataset: ... + @overload + def __ror__(self, other: T_DataArray) -> T_DataArray: ... + @overload + def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 3fc682f8c32..0965d440fc7 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -9,6 +9,7 @@ is_np_datetime_like, is_np_timedelta_like, ) +from .npcompat import DTypeLike from .pycompat import is_duck_dask_array @@ -30,6 +31,10 @@ def _access_through_cftimeindex(values, name): if name == "season": months = values_as_cftimeindex.month field_values = _season_from_months(months) + elif name == "date": + raise AttributeError( + "'CFTimeIndex' object has no attribute `date`. Consider using the floor method instead, for instance: `.time.dt.floor('D')`." + ) else: field_values = getattr(values_as_cftimeindex, name) return field_values.reshape(values.shape) @@ -178,8 +183,9 @@ class Properties: def __init__(self, obj): self._obj = obj - def _tslib_field_accessor( # type: ignore - name: str, docstring: str = None, dtype: np.dtype = None + @staticmethod + def _tslib_field_accessor( + name: str, docstring: str = None, dtype: DTypeLike = None ): def f(self, dtype=dtype): if dtype is None: @@ -257,8 +263,6 @@ class DatetimeAccessor(Properties): Examples --------- - >>> import xarray as xr - >>> import pandas as pd >>> dates = pd.date_range(start="2000/01/01", freq="D", periods=10) >>> ts = xr.DataArray(dates, dims=("time")) >>> ts @@ -322,8 +326,8 @@ def strftime(self, date_format): def isocalendar(self): """Dataset containing ISO year, week number, and weekday. - Note - ---- + Notes + ----- The iso year and weekday differ from the nominal year and weekday. """ @@ -413,6 +417,10 @@ def weekofyear(self): "time", "Timestamps corresponding to datetimes", object ) + date = Properties._tslib_field_accessor( + "date", "Date corresponding to datetimes", object + ) + is_month_start = Properties._tslib_field_accessor( "is_month_start", "Indicates whether the date is the first day of the month.", @@ -449,8 +457,6 @@ class TimedeltaAccessor(Properties): Examples -------- - >>> import pandas as pd - >>> import xarray as xr >>> dates = pd.timedelta_range(start="1 day", freq="6H", periods=20) >>> ts = xr.DataArray(dates, dims=("time")) >>> ts diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 02d8ca00bf9..e3c35d6e4b6 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -40,6 +40,20 @@ import codecs import re import textwrap +from functools import reduce +from operator import or_ as set_union +from typing import ( + Any, + Callable, + Hashable, + Mapping, + Optional, + Pattern, + Tuple, + Type, + Union, +) +from unicodedata import normalize import numpy as np @@ -57,12 +71,76 @@ _cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32") -def _is_str_like(x): - return isinstance(x, str) or isinstance(x, bytes) +def _contains_obj_type(*, pat: Any, checker: Any) -> bool: + """Determine if the object fits some rule or is array of objects that do so.""" + if isinstance(checker, type): + targtype = checker + checker = lambda x: isinstance(x, targtype) + + if checker(pat): + return True + + # If it is not an object array it can't contain compiled re + if getattr(pat, "dtype", "no") != np.object_: + return False + + return _apply_str_ufunc(func=checker, obj=pat).all() + + +def _contains_str_like(pat: Any) -> bool: + """Determine if the object is a str-like or array of str-like.""" + if isinstance(pat, (str, bytes)): + return True + + if not hasattr(pat, "dtype"): + return False + + return pat.dtype.kind in ["U", "S"] + + +def _contains_compiled_re(pat: Any) -> bool: + """Determine if the object is a compiled re or array of compiled re.""" + return _contains_obj_type(pat=pat, checker=re.Pattern) + + +def _contains_callable(pat: Any) -> bool: + """Determine if the object is a callable or array of callables.""" + return _contains_obj_type(pat=pat, checker=callable) + + +def _apply_str_ufunc( + *, + func: Callable, + obj: Any, + dtype: Union[str, np.dtype, Type] = None, + output_core_dims: Union[list, tuple] = ((),), + output_sizes: Mapping[Hashable, int] = None, + func_args: Tuple = (), + func_kwargs: Mapping = {}, +) -> Any: + # TODO handling of na values ? + if dtype is None: + dtype = obj.dtype + + dask_gufunc_kwargs = dict() + if output_sizes is not None: + dask_gufunc_kwargs["output_sizes"] = output_sizes + + return apply_ufunc( + func, + obj, + *func_args, + vectorize=True, + dask="parallelized", + output_dtypes=[dtype], + output_core_dims=output_core_dims, + dask_gufunc_kwargs=dask_gufunc_kwargs, + **func_kwargs, + ) class StringAccessor: - """Vectorized string functions for string-like arrays. + r"""Vectorized string functions for string-like arrays. Similar to pandas, fields can be accessed through the `.str` attribute for applicable DataArrays. @@ -73,6 +151,55 @@ class StringAccessor: array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 + It also implements ``+``, ``*``, and ``%``, which operate as elementwise + versions of the corresponding ``str`` methods. These will automatically + broadcast for array-like inputs. + + >>> da1 = xr.DataArray(["first", "second", "third"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1.str + da2 + + array([['first1', 'first2', 'first3'], + ['second1', 'second2', 'second3'], + ['third1', 'third2', 'third3']], dtype='>> da1 = xr.DataArray(["a", "b", "c", "d"], dims=["X"]) + >>> reps = xr.DataArray([3, 4], dims=["Y"]) + >>> da1.str * reps + + array([['aaa', 'aaaa'], + ['bbb', 'bbbb'], + ['ccc', 'cccc'], + ['ddd', 'dddd']], dtype='>> da1 = xr.DataArray(["%s_%s", "%s-%s", "%s|%s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2], dims=["Y"]) + >>> da3 = xr.DataArray([0.1, 0.2], dims=["Z"]) + >>> da1.str % (da2, da3) + + array([[['1_0.1', '1_0.2'], + ['2_0.1', '2_0.2']], + + [['1-0.1', '1-0.2'], + ['2-0.1', '2-0.2']], + + [['1|0.1', '1|0.2'], + ['2|0.1', '2|0.2']]], dtype='>> da1 = xr.DataArray(["%(a)s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1 % {"a": da2} + + array(['\narray([1, 2, 3])\nDimensions without coordinates: Y'], + dtype=object) + Dimensions without coordinates: X """ __slots__ = ("_obj",) @@ -80,15 +207,81 @@ class StringAccessor: def __init__(self, obj): self._obj = obj - def _apply(self, f, dtype=None): - # TODO handling of na values ? - if dtype is None: - dtype = self._obj.dtype + def _stringify( + self, + invar: Any, + ) -> Union[str, bytes, Any]: + """ + Convert a string-like to the correct string/bytes type. - g = np.vectorize(f, otypes=[dtype]) - return apply_ufunc(g, self._obj, dask="parallelized", output_dtypes=[dtype]) + This is mostly here to tell mypy a pattern is a str/bytes not a re.Pattern. + """ + if hasattr(invar, "astype"): + return invar.astype(self._obj.dtype.kind) + else: + return self._obj.dtype.type(invar) + + def _apply( + self, + *, + func: Callable, + dtype: Union[str, np.dtype, Type] = None, + output_core_dims: Union[list, tuple] = ((),), + output_sizes: Mapping[Hashable, int] = None, + func_args: Tuple = (), + func_kwargs: Mapping = {}, + ) -> Any: + return _apply_str_ufunc( + obj=self._obj, + func=func, + dtype=dtype, + output_core_dims=output_core_dims, + output_sizes=output_sizes, + func_args=func_args, + func_kwargs=func_kwargs, + ) + + def _re_compile( + self, + *, + pat: Union[str, bytes, Pattern, Any], + flags: int = 0, + case: bool = None, + ) -> Union[Pattern, Any]: + is_compiled_re = isinstance(pat, re.Pattern) + + if is_compiled_re and flags != 0: + raise ValueError("Flags cannot be set when pat is a compiled regex.") + + if is_compiled_re and case is not None: + raise ValueError("Case cannot be set when pat is a compiled regex.") + + if is_compiled_re: + # no-op, needed to tell mypy this isn't a string + return re.compile(pat) + + if case is None: + case = True + + # The case is handled by the re flags internally. + # Add it to the flags if necessary. + if not case: + flags |= re.IGNORECASE + + if getattr(pat, "dtype", None) != np.object_: + pat = self._stringify(pat) + + def func(x): + return re.compile(x, flags=flags) - def len(self): + if isinstance(pat, np.ndarray): + # apply_ufunc doesn't work for numpy arrays with output object dtypes + func = np.vectorize(func) + return func(pat) + else: + return _apply_str_ufunc(func=func, obj=pat, dtype=np.object_) + + def len(self) -> Any: """ Compute the length of each string in the array. @@ -96,22 +289,58 @@ def len(self): ------- lengths array : array of int """ - return self._apply(len, dtype=int) + return self._apply(func=len, dtype=int) - def __getitem__(self, key): + def __getitem__( + self, + key: Union[int, slice], + ) -> Any: if isinstance(key, slice): return self.slice(start=key.start, stop=key.stop, step=key.step) else: return self.get(key) - def get(self, i, default=""): + def __add__( + self, + other: Any, + ) -> Any: + return self.cat(other, sep="") + + def __mul__( + self, + num: Union[int, Any], + ) -> Any: + return self.repeat(num) + + def __mod__( + self, + other: Any, + ) -> Any: + if isinstance(other, dict): + other = {key: self._stringify(val) for key, val in other.items()} + return self._apply(func=lambda x: x % other) + elif isinstance(other, tuple): + other = tuple(self._stringify(x) for x in other) + return self._apply(func=lambda x, *y: x % y, func_args=other) + else: + return self._apply(func=lambda x, y: x % y, func_args=(other,)) + + def get( + self, + i: Union[int, Any], + default: Union[str, bytes] = "", + ) -> Any: """ Extract character number `i` from each string in the array. + If `i` is array-like, they are broadcast against the array and + applied elementwise. + Parameters ---------- - i : int + i : int or array-like of int Position of element to extract. + If array-like, it is broadcast. default : optional Value for out-of-range index. If not specified (None) defaults to an empty string. @@ -120,76 +349,334 @@ def get(self, i, default=""): ------- items : array of object """ - s = slice(-1, None) if i == -1 else slice(i, i + 1) - def f(x): - item = x[s] + def f(x, iind): + islice = slice(-1, None) if iind == -1 else slice(iind, iind + 1) + item = x[islice] return item if item else default - return self._apply(f) + return self._apply(func=f, func_args=(i,)) - def slice(self, start=None, stop=None, step=None): + def slice( + self, + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + step: Union[int, Any] = None, + ) -> Any: """ Slice substrings from each string in the array. + If `start`, `stop`, or 'step` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Start position for slice operation. - stop : int, optional + If array-like, it is broadcast. + stop : int or array-like of int, optional Stop position for slice operation. - step : int, optional + If array-like, it is broadcast. + step : int or array-like of int, optional Step size for slice operation. + If array-like, it is broadcast. Returns ------- sliced strings : same type as values """ - s = slice(start, stop, step) - f = lambda x: x[s] - return self._apply(f) + f = lambda x, istart, istop, istep: x[slice(istart, istop, istep)] + return self._apply(func=f, func_args=(start, stop, step)) - def slice_replace(self, start=None, stop=None, repl=""): + def slice_replace( + self, + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + repl: Union[str, bytes, Any] = "", + ) -> Any: """ Replace a positional slice of a string with another value. + If `start`, `stop`, or 'repl` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Left index position to use for the slice. If not specified (None), the slice is unbounded on the left, i.e. slice from the start - of the string. - stop : int, optional + of the string. If array-like, it is broadcast. + stop : int or array-like of int, optional Right index position to use for the slice. If not specified (None), the slice is unbounded on the right, i.e. slice until the - end of the string. - repl : str, optional + end of the string. If array-like, it is broadcast. + repl : str or array-like of str, optional String for replacement. If not specified, the sliced region - is replaced with an empty string. + is replaced with an empty string. If array-like, it is broadcast. Returns ------- replaced : same type as values """ - repl = self._obj.dtype.type(repl) + repl = self._stringify(repl) - def f(x): - if len(x[start:stop]) == 0: - local_stop = start + def func(x, istart, istop, irepl): + if len(x[istart:istop]) == 0: + local_stop = istart else: - local_stop = stop - y = self._obj.dtype.type("") - if start is not None: - y += x[:start] - y += repl - if stop is not None: + local_stop = istop + y = self._stringify("") + if istart is not None: + y += x[:istart] + y += irepl + if istop is not None: y += x[local_stop:] return y - return self._apply(f) + return self._apply(func=func, func_args=(start, stop, repl)) + + def cat( + self, + *others, + sep: Union[str, bytes, Any] = "", + ) -> Any: + """ + Concatenate strings elementwise in the DataArray with other strings. + + The other strings can either be string scalars or other array-like. + Dimensions are automatically broadcast together. + + An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. + + Parameters + ---------- + *others : str or array-like of str + Strings or array-like of strings to concatenate elementwise with + the current DataArray. + sep : str or array-like of str, default: "". + Seperator to use between strings. + It is broadcast in the same way as the other input strings. + If array-like, its dimensions will be placed at the end of the output array dimensions. + + Returns + ------- + concatenated : same type as values + + Examples + -------- + Create a string array + + >>> myarray = xr.DataArray( + ... ["11111", "4"], + ... dims=["X"], + ... ) + + Create some arrays to concatenate with it + + >>> values_1 = xr.DataArray( + ... ["a", "bb", "cccc"], + ... dims=["Y"], + ... ) + >>> values_2 = np.array(3.4) + >>> values_3 = "" + >>> values_4 = np.array("test", dtype=np.unicode_) + + Determine the separator to use + + >>> seps = xr.DataArray( + ... [" ", ", "], + ... dims=["ZZ"], + ... ) + + Concatenate the arrays using the separator + + >>> myarray.str.cat(values_1, values_2, values_3, values_4, sep=seps) + + array([[['11111 a 3.4 test', '11111, a, 3.4, , test'], + ['11111 bb 3.4 test', '11111, bb, 3.4, , test'], + ['11111 cccc 3.4 test', '11111, cccc, 3.4, , test']], + + [['4 a 3.4 test', '4, a, 3.4, , test'], + ['4 bb 3.4 test', '4, bb, 3.4, , test'], + ['4 cccc 3.4 test', '4, cccc, 3.4, , test']]], dtype=' Any: + """ + Concatenate strings in a DataArray along a particular dimension. + + An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. + + Parameters + ---------- + dim : hashable, optional + Dimension along which the strings should be concatenated. + Only one dimension is allowed at a time. + Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. + sep : str or array-like, default: "". + Seperator to use between strings. + It is broadcast in the same way as the other input strings. + If array-like, its dimensions will be placed at the end of the output array dimensions. + + Returns + ------- + joined : same type as values + + Examples + -------- + Create an array + + >>> values = xr.DataArray( + ... [["a", "bab", "abc"], ["abcd", "", "abcdef"]], + ... dims=["X", "Y"], + ... ) + + Determine the separator + + >>> seps = xr.DataArray( + ... ["-", "_"], + ... dims=["ZZ"], + ... ) + + Join the strings along a given dimension - def capitalize(self): + >>> values.str.join(dim="Y", sep=seps) + + array([['a-bab-abc', 'a_bab_abc'], + ['abcd--abcdef', 'abcd__abcdef']], dtype=' 1 and dim is None: + raise ValueError("Dimension must be specified for multidimensional arrays.") + + if self._obj.ndim > 1: + # Move the target dimension to the start and split along it + dimshifted = list(self._obj.transpose(dim, ...)) + elif self._obj.ndim == 1: + dimshifted = list(self._obj) + else: + dimshifted = [self._obj] + + start, *others = dimshifted + + # concatenate the resulting arrays + return start.str.cat(*others, sep=sep) + + def format( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Perform python string formatting on each element of the DataArray. + + This is equivalent to calling `str.format` on every element of the + DataArray. The replacement values can either be a string-like + scalar or array-like of string-like values. If array-like, + the values will be broadcast and applied elementwiseto the input + DataArray. + + .. note:: + Array-like values provided as `*args` will have their + dimensions added even if those arguments are not used in any + string formatting. + + .. warning:: + Array-like arguments are only applied elementwise for `*args`. + For `**kwargs`, values are used as-is. + + Parameters + ---------- + *args : str or bytes or array-like of str or bytes + Values for positional formatting. + If array-like, the values are broadcast and applied elementwise. + The dimensions will be placed at the end of the output array dimensions + in the order they are provided. + **kwargs : str or bytes or array-like of str or bytes + Values for keyword-based formatting. + These are **not** broadcast or applied elementwise. + + Returns + ------- + formatted : same type as values + + Examples + -------- + Create an array to format. + + >>> values = xr.DataArray( + ... ["{} is {adj0}", "{} and {} are {adj1}"], + ... dims=["X"], + ... ) + + Set the values to fill. + + >>> noun0 = xr.DataArray( + ... ["spam", "egg"], + ... dims=["Y"], + ... ) + >>> noun1 = xr.DataArray( + ... ["lancelot", "arthur"], + ... dims=["ZZ"], + ... ) + >>> adj0 = "unexpected" + >>> adj1 = "like a duck" + + Insert the values into the array + + >>> values.str.format(noun0, noun1, adj0=adj0, adj1=adj1) + + array([[['spam is unexpected', 'spam is unexpected'], + ['egg is unexpected', 'egg is unexpected']], + + [['spam and lancelot are like a duck', + 'spam and arthur are like a duck'], + ['egg and lancelot are like a duck', + 'egg and arthur are like a duck']]], dtype=' Any: """ Convert strings in the array to be capitalized. @@ -197,9 +684,9 @@ def capitalize(self): ------- capitalized : same type as values """ - return self._apply(lambda x: x.capitalize()) + return self._apply(func=lambda x: x.capitalize()) - def lower(self): + def lower(self) -> Any: """ Convert strings in the array to lowercase. @@ -207,9 +694,9 @@ def lower(self): ------- lowerd : same type as values """ - return self._apply(lambda x: x.lower()) + return self._apply(func=lambda x: x.lower()) - def swapcase(self): + def swapcase(self) -> Any: """ Convert strings in the array to be swapcased. @@ -217,9 +704,9 @@ def swapcase(self): ------- swapcased : same type as values """ - return self._apply(lambda x: x.swapcase()) + return self._apply(func=lambda x: x.swapcase()) - def title(self): + def title(self) -> Any: """ Convert strings in the array to titlecase. @@ -227,9 +714,9 @@ def title(self): ------- titled : same type as values """ - return self._apply(lambda x: x.title()) + return self._apply(func=lambda x: x.title()) - def upper(self): + def upper(self) -> Any: """ Convert strings in the array to uppercase. @@ -237,9 +724,46 @@ def upper(self): ------- uppered : same type as values """ - return self._apply(lambda x: x.upper()) + return self._apply(func=lambda x: x.upper()) - def isalnum(self): + def casefold(self) -> Any: + """ + Convert strings in the array to be casefolded. + + Casefolding is similar to converting to lowercase, + but removes all case distinctions. + This is important in some languages that have more complicated + cases and case conversions. + + Returns + ------- + casefolded : same type as values + """ + return self._apply(func=lambda x: x.casefold()) + + def normalize( + self, + form: str, + ) -> Any: + """ + Return the Unicode normal form for the strings in the datarray. + + For more information on the forms, see the documentation for + :func:`unicodedata.normalize`. + + Parameters + ---------- + form : {"NFC", "NFKC", "NFD", "NFKD"} + Unicode form. + + Returns + ------- + normalized : same type as values + + """ + return self._apply(func=lambda x: normalize(form, x)) + + def isalnum(self) -> Any: """ Check whether all characters in each string are alphanumeric. @@ -248,9 +772,9 @@ def isalnum(self): isalnum : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isalnum(), dtype=bool) + return self._apply(func=lambda x: x.isalnum(), dtype=bool) - def isalpha(self): + def isalpha(self) -> Any: """ Check whether all characters in each string are alphabetic. @@ -259,9 +783,9 @@ def isalpha(self): isalpha : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isalpha(), dtype=bool) + return self._apply(func=lambda x: x.isalpha(), dtype=bool) - def isdecimal(self): + def isdecimal(self) -> Any: """ Check whether all characters in each string are decimal. @@ -270,9 +794,9 @@ def isdecimal(self): isdecimal : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdecimal(), dtype=bool) + return self._apply(func=lambda x: x.isdecimal(), dtype=bool) - def isdigit(self): + def isdigit(self) -> Any: """ Check whether all characters in each string are digits. @@ -281,9 +805,9 @@ def isdigit(self): isdigit : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdigit(), dtype=bool) + return self._apply(func=lambda x: x.isdigit(), dtype=bool) - def islower(self): + def islower(self) -> Any: """ Check whether all characters in each string are lowercase. @@ -292,9 +816,9 @@ def islower(self): islower : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.islower(), dtype=bool) + return self._apply(func=lambda x: x.islower(), dtype=bool) - def isnumeric(self): + def isnumeric(self) -> Any: """ Check whether all characters in each string are numeric. @@ -303,9 +827,9 @@ def isnumeric(self): isnumeric : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isnumeric(), dtype=bool) + return self._apply(func=lambda x: x.isnumeric(), dtype=bool) - def isspace(self): + def isspace(self) -> Any: """ Check whether all characters in each string are spaces. @@ -314,9 +838,9 @@ def isspace(self): isspace : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isspace(), dtype=bool) + return self._apply(func=lambda x: x.isspace(), dtype=bool) - def istitle(self): + def istitle(self) -> Any: """ Check whether all characters in each string are titlecase. @@ -325,9 +849,9 @@ def istitle(self): istitle : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.istitle(), dtype=bool) + return self._apply(func=lambda x: x.istitle(), dtype=bool) - def isupper(self): + def isupper(self) -> Any: """ Check whether all characters in each string are uppercase. @@ -336,9 +860,14 @@ def isupper(self): isupper : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isupper(), dtype=bool) + return self._apply(func=lambda x: x.isupper(), dtype=bool) - def count(self, pat, flags=0): + def count( + self, + pat: Union[str, bytes, Pattern, Any], + flags: int = 0, + case: bool = None, + ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -346,31 +875,49 @@ def count(self, pat, flags=0): pattern is repeated in each of the string elements of the :class:`~xarray.DataArray`. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str - Valid regular expression. + pat : str or re.Pattern or array-like of str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. If array-like, it is broadcast. flags : int, default: 0 - Flags for the `re` module. Use 0 for no flags. For a complete list, - `see here `_. + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. Returns ------- counts : array of int """ - pat = self._obj.dtype.type(pat) - regex = re.compile(pat, flags=flags) - f = lambda x: len(regex.findall(x)) - return self._apply(f, dtype=int) + pat = self._re_compile(pat=pat, flags=flags, case=case) - def startswith(self, pat): + func = lambda x, ipat: len(ipat.findall(x)) + return self._apply(func=func, func_args=(pat,), dtype=int) + + def startswith( + self, + pat: Union[str, bytes, Any], + ) -> Any: """ Test if the start of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -378,18 +925,25 @@ def startswith(self, pat): An array of booleans indicating whether the given pattern matches the start of each string element. """ - pat = self._obj.dtype.type(pat) - f = lambda x: x.startswith(pat) - return self._apply(f, dtype=bool) + pat = self._stringify(pat) + func = lambda x, y: x.startswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) - def endswith(self, pat): + def endswith( + self, + pat: Union[str, bytes, Any], + ) -> Any: """ Test if the end of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -397,100 +951,151 @@ def endswith(self, pat): A Series of booleans indicating whether the given pattern matches the end of each string element. """ - pat = self._obj.dtype.type(pat) - f = lambda x: x.endswith(pat) - return self._apply(f, dtype=bool) + pat = self._stringify(pat) + func = lambda x, y: x.endswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) - def pad(self, width, side="left", fillchar=" "): + def pad( + self, + width: Union[int, Any], + side: str = "left", + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: """ Pad strings in the array up to width. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with character defined in `fillchar`. + filled with character defined in ``fillchar``. + If array-like, it is broadcast. side : {"left", "right", "both"}, default: "left" Side from which to fill resulting string. - fillchar : str, default: " " - Additional character for filling, default is whitespace. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values Array with a minimum number of char in each element. """ - width = int(width) - fillchar = self._obj.dtype.type(fillchar) - if len(fillchar) != 1: - raise TypeError("fillchar must be a character, not str") - if side == "left": - f = lambda s: s.rjust(width, fillchar) + func = self.rjust elif side == "right": - f = lambda s: s.ljust(width, fillchar) + func = self.ljust elif side == "both": - f = lambda s: s.center(width, fillchar) + func = self.center else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return func(width=width, fillchar=fillchar) + + def _padder( + self, + *, + func: Callable, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: + """ + Wrapper function to handle padding operations + """ + fillchar = self._stringify(fillchar) + + def overfunc(x, iwidth, ifillchar): + if len(ifillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return func(x, int(iwidth), ifillchar) - def center(self, width, fillchar=" "): + return self._apply(func=overfunc, func_args=(width, fillchar)) + + def center( + self, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: """ Pad left and right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="both", fillchar=fillchar) + func = self._obj.dtype.type.center + return self._padder(func=func, width=width, fillchar=fillchar) - def ljust(self, width, fillchar=" "): + def ljust( + self, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: """ Pad right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="right", fillchar=fillchar) + func = self._obj.dtype.type.ljust + return self._padder(func=func, width=width, fillchar=fillchar) - def rjust(self, width, fillchar=" "): + def rjust( + self, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: """ Pad left side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar=fillchar) + func = self._obj.dtype.type.rjust + return self._padder(func=func, width=width, fillchar=fillchar) - def zfill(self, width): + def zfill(self, width: Union[int, Any]) -> Any: """ Pad each string in the array by prepending '0' characters. @@ -498,37 +1103,56 @@ def zfill(self, width): left of the string to reach a total string length `width`. Strings in the array with length greater or equal to `width` are unchanged. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum length of resulting string; strings with length less - than `width` be prepended with '0' characters. + than `width` be prepended with '0' characters. If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar="0") + return self.rjust(width, fillchar="0") - def contains(self, pat, case=True, flags=0, regex=True): + def contains( + self, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, + flags: int = 0, + regex: bool = True, + ) -> Any: """ Test if pattern or regex is contained within each string of the array. Return boolean array based on whether a given pattern or regex is contained within a string of the array. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str - Character sequence or regular expression. + pat : str or re.Pattern or array-like of str or re.Pattern + Character sequence, a string containing a regular expression, + or a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - Flags to pass through to the re module, e.g. re.IGNORECASE. - ``0`` means no flags. + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the pat is a regular expression. If False, treats the pat as a literal string. + Cannot be set to `False` if `pat` is a compiled regex. Returns ------- @@ -537,65 +1161,94 @@ def contains(self, pat, case=True, flags=0, regex=True): given pattern is contained within the string of each element of the array. """ - pat = self._obj.dtype.type(pat) - if regex: - if not case: - flags |= re.IGNORECASE + is_compiled_re = _contains_compiled_re(pat) + if is_compiled_re and not regex: + raise ValueError( + "Must use regular expression matching for regular expression object." + ) - regex = re.compile(pat, flags=flags) + if regex: + if not is_compiled_re: + pat = self._re_compile(pat=pat, flags=flags, case=case) - if regex.groups > 0: # pragma: no cover - raise ValueError("This pattern has match groups.") + def func(x, ipat): + if ipat.groups > 0: # pragma: no cover + raise ValueError("This pattern has match groups.") + return bool(ipat.search(x)) - f = lambda x: bool(regex.search(x)) else: - if case: - f = lambda x: pat in x + pat = self._stringify(pat) + if case or case is None: + func = lambda x, ipat: ipat in x + elif self._obj.dtype.char == "U": + uppered = self._obj.str.casefold() + uppat = StringAccessor(pat).casefold() + return uppered.str.contains(uppat, regex=False) else: uppered = self._obj.str.upper() - return uppered.str.contains(pat.upper(), regex=False) + uppat = StringAccessor(pat).upper() + return uppered.str.contains(uppat, regex=False) - return self._apply(f, dtype=bool) + return self._apply(func=func, func_args=(pat,), dtype=bool) - def match(self, pat, case=True, flags=0): + def match( + self, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, + flags: int = 0, + ) -> Any: """ Determine if each string in the array matches a regular expression. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str - Character sequence or regular expression + pat : str or re.Pattern or array-like of str or re.Pattern + A string containing a regular expression or + a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True - If True, case sensitive + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - re module flags, e.g. re.IGNORECASE. ``0`` means no flags + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. Returns ------- matched : array of bool """ - if not case: - flags |= re.IGNORECASE + pat = self._re_compile(pat=pat, flags=flags, case=case) - pat = self._obj.dtype.type(pat) - regex = re.compile(pat, flags=flags) - f = lambda x: bool(regex.match(x)) - return self._apply(f, dtype=bool) + func = lambda x, ipat: bool(ipat.match(x)) + return self._apply(func=func, func_args=(pat,), dtype=bool) - def strip(self, to_strip=None, side="both"): + def strip( + self, + to_strip: Union[str, bytes, Any] = None, + side: str = "both", + ) -> Any: """ Remove leading and trailing characters. Strip whitespaces (including newlines) or a set of specified characters from each string in the array from left and/or right sides. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. - side : {"left", "right", "both"}, default: "left" + If None then whitespaces are removed. If array-like, it is broadcast. + side : {"left", "right", "both"}, default: "both" Side from which to strip. Returns @@ -603,32 +1256,38 @@ def strip(self, to_strip=None, side="both"): stripped : same type as values """ if to_strip is not None: - to_strip = self._obj.dtype.type(to_strip) + to_strip = self._stringify(to_strip) if side == "both": - f = lambda x: x.strip(to_strip) + func = lambda x, y: x.strip(y) elif side == "left": - f = lambda x: x.lstrip(to_strip) + func = lambda x, y: x.lstrip(y) elif side == "right": - f = lambda x: x.rstrip(to_strip) + func = lambda x, y: x.rstrip(y) else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return self._apply(func=func, func_args=(to_strip,)) - def lstrip(self, to_strip=None): + def lstrip( + self, + to_strip: Union[str, bytes, Any] = None, + ) -> Any: """ Remove leading characters. Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the left side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -636,19 +1295,25 @@ def lstrip(self, to_strip=None): """ return self.strip(to_strip, side="left") - def rstrip(self, to_strip=None): + def rstrip( + self, + to_strip: Union[str, bytes, Any] = None, + ) -> Any: """ Remove trailing characters. Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the right side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -656,17 +1321,25 @@ def rstrip(self, to_strip=None): """ return self.strip(to_strip, side="right") - def wrap(self, width, **kwargs): + def wrap( + self, + width: Union[int, Any], + **kwargs, + ) -> Any: """ Wrap long strings in the array in paragraphs with length less than `width`. This method has the same keyword parameters and defaults as :class:`textwrap.TextWrapper`. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int - Maximum line-width + width : int or array-like of int + Maximum line-width. + If array-like, it is broadcast. **kwargs keyword arguments passed into :class:`textwrap.TextWrapper`. @@ -674,11 +1347,15 @@ def wrap(self, width, **kwargs): ------- wrapped : same type as values """ - tw = textwrap.TextWrapper(width=width, **kwargs) - f = lambda x: "\n".join(tw.wrap(x)) - return self._apply(f) + ifunc = lambda x: textwrap.TextWrapper(width=x, **kwargs) + tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) + func = lambda x, itw: "\n".join(itw.wrap(x)) + return self._apply(func=func, func_args=(tw,)) - def translate(self, table): + def translate( + self, + table: Mapping[Union[str, bytes], Union[str, bytes]], + ) -> Any: """ Map characters of each string through the given mapping table. @@ -694,40 +1371,59 @@ def translate(self, table): ------- translated : same type as values """ - f = lambda x: x.translate(table) - return self._apply(f) + func = lambda x: x.translate(table) + return self._apply(func=func) - def repeat(self, repeats): + def repeat( + self, + repeats: Union[int, Any], + ) -> Any: """ - Duplicate each string in the array. + Repeat each string in the array. + + If `repeats` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - repeats : int + repeats : int or array-like of int Number of repetitions. + If array-like, it is broadcast. Returns ------- repeated : same type as values Array of repeated string objects. """ - f = lambda x: repeats * x - return self._apply(f) + func = lambda x, y: x * y + return self._apply(func=func, func_args=(repeats,)) - def find(self, sub, start=0, end=None, side="left"): + def find( + self, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, + side: str = "left", + ) -> Any: """ Return lowest or highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. @@ -735,7 +1431,7 @@ def find(self, sub, start=0, end=None, side="left"): ------- found : array of int """ - sub = self._obj.dtype.type(sub) + sub = self._stringify(sub) if side == "left": method = "find" @@ -744,27 +1440,34 @@ def find(self, sub, start=0, end=None, side="left"): else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) - def rfind(self, sub, start=0, end=None): + def rfind( + self, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, + ) -> Any: """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- @@ -772,29 +1475,46 @@ def rfind(self, sub, start=0, end=None): """ return self.find(sub, start=start, end=end, side="right") - def index(self, sub, start=0, end=None, side="left"): + def index( + self, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, + side: str = "left", + ) -> Any: """ Return lowest or highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as ``str.find`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. Returns ------- found : array of int + + Raises + ------ + ValueError + substring is not found """ - sub = self._obj.dtype.type(sub) + sub = self._stringify(sub) if side == "left": method = "index" @@ -803,61 +1523,89 @@ def index(self, sub, start=0, end=None, side="left"): else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) - def rindex(self, sub, start=0, end=None): + def rindex( + self, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, + ) -> Any: """ Return highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as ``str.rfind`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- found : array of int + + Raises + ------ + ValueError + substring is not found """ return self.index(sub, start=start, end=end, side="right") - def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): + def replace( + self, + pat: Union[str, bytes, Pattern, Any], + repl: Union[str, bytes, Callable, Any], + n: Union[int, Any] = -1, + case: bool = None, + flags: int = 0, + regex: bool = True, + ) -> Any: """ Replace occurrences of pattern/regex in the array with some string. + If `pat`, `repl`, or 'n` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern String can be a character sequence or regular expression. - repl : str or callable + If array-like, it is broadcast. + repl : str or callable or array-like of str or callable Replacement string or a callable. The callable is passed the regex match object and must return a replacement string to be used. See :func:`re.sub`. - n : int, default: -1 + If array-like, it is broadcast. + n : int or array of int, default: -1 Number of replacements to make from start. Use ``-1`` to replace all. - case : bool, default: None - - If True, case sensitive (the default if `pat` is a string) - - Set to False for case insensitive - - Cannot be set if `pat` is a compiled regex + If array-like, it is broadcast. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. flags : int, default: 0 - - re module flags, e.g. re.IGNORECASE. Use ``0`` for no flags. - - Cannot be set if `pat` is a compiled regex + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. regex : bool, default: True - - If True, assumes the passed-in pattern is a regular expression. - - If False, treats the pattern as a literal string - - Cannot be set to False if `pat` is a compiled regex or `repl` is - a callable. + If True, assumes the passed-in pattern is a regular expression. + If False, treats the pattern as a literal string. + Cannot be set to False if `pat` is a compiled regex or `repl` is + a callable. Returns ------- @@ -865,84 +1613,968 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): A copy of the object with all matching occurrences of `pat` replaced by `repl`. """ - if not (_is_str_like(repl) or callable(repl)): # pragma: no cover + if _contains_str_like(repl): + repl = self._stringify(repl) + elif not _contains_callable(repl): # pragma: no cover raise TypeError("repl must be a string or callable") - if _is_str_like(pat): - pat = self._obj.dtype.type(pat) + is_compiled_re = _contains_compiled_re(pat) + if not regex and is_compiled_re: + raise ValueError( + "Cannot use a compiled regex as replacement pattern with regex=False" + ) - if _is_str_like(repl): - repl = self._obj.dtype.type(repl) + if not regex and callable(repl): + raise ValueError("Cannot use a callable replacement when regex=False") - is_compiled_re = isinstance(pat, type(re.compile(""))) if regex: - if is_compiled_re: - if (case is not None) or (flags != 0): - raise ValueError( - "case and flags cannot be set when pat is a compiled regex" - ) - else: - # not a compiled regex - # set default case - if case is None: - case = True - - # add case flag, if provided - if case is False: - flags |= re.IGNORECASE - if is_compiled_re or len(pat) > 1 or flags or callable(repl): - n = n if n >= 0 else 0 - compiled = re.compile(pat, flags=flags) - f = lambda x: compiled.sub(repl=repl, string=x, count=n) - else: - f = lambda x: x.replace(pat, repl, n) + pat = self._re_compile(pat=pat, flags=flags, case=case) + func = lambda x, ipat, irepl, i_n: ipat.sub( + repl=irepl, string=x, count=i_n if i_n >= 0 else 0 + ) + else: + pat = self._stringify(pat) + func = lambda x, ipat, irepl, i_n: x.replace(ipat, irepl, i_n) + return self._apply(func=func, func_args=(pat, repl, n)) + + def extract( + self, + pat: Union[str, bytes, Pattern, Any], + dim: Hashable, + case: bool = None, + flags: int = 0, + ) -> Any: + r""" + Extract the first match of capture groups in the regex pat as a new + dimension in a DataArray. + + For each string in the DataArray, extract groups from the first match + of regular expression pat. + + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + pat : str or re.Pattern or array-like of str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. If array-like, it is broadcast. + dim : hashable or None + Name of the new dimension to store the captured strings in. + If None, the pattern must have only one capture group and the + resulting DataArray will have the same size as the original. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : same type as values or object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `dim` is None and there is more than one capture group. + ValueError + `case` is set when `pat` is a compiled regular expression. + KeyError + The given dimension is already present in the DataArray. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... ], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.extract(r"(\w+)_Xy_(\d*)", dim="match") + + array([[['a', '0'], + ['bab', '110'], + ['abc', '01']], + + [['abcd', ''], + ['', ''], + ['abcdef', '101']]], dtype=' Any: + r""" + Extract all matches of capture groups in the regex pat as new + dimensions in a DataArray. + + For each string in the DataArray, extract groups from all matches + of regular expression pat. + Equivalent to applying re.findall() to all the elements in the DataArray + and splitting the results across dimensions. + + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + pat : str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. If array-like, it is broadcast. + group_dim : hashable + Name of the new dimensions corresponding to the capture groups. + This dimension is added to the new DataArray first. + match_dim : hashable + Name of the new dimensions corresponding to the matches for each group. + This dimension is added to the new DataArray second. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : same type as values or object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `case` is set when `pat` is a compiled regular expression. + KeyError + Either of the given dimensions is already present in the DataArray. + KeyError + The given dimensions names are the same. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... ], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.extractall( + ... r"(\w+)_Xy_(\d*)", group_dim="group", match_dim="match" + ... ) + + array([[[['a', '0'], + ['', ''], + ['', '']], + + [['bab', '110'], + ['baab', '1100'], + ['', '']], + + [['abc', '01'], + ['cbc', '2210'], + ['', '']]], + + + [[['abcd', ''], + ['dcd', '33210'], + ['dccd', '332210']], + + [['', ''], + ['', ''], + ['', '']], + + [['abcdef', '101'], + ['fef', '5543210'], + ['', '']]]], dtype=' Any: + r""" + Find all occurrences of pattern or regular expression in the DataArray. + + Equivalent to applying re.findall() to all the elements in the DataArray. + Results in an object array of lists. + If there is only one capture group, the lists will be a sequence of matches. + If there are multiple capture groups, the lists will be a sequence of lists, + each of which contains a sequence of matches. + + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + pat : str or re.Pattern + A string containing a regular expression or a compiled regular + expression object. If array-like, it is broadcast. + case : bool, default: True + If True, case sensitive. + Cannot be set if `pat` is a compiled regex. + Equivalent to setting the `re.IGNORECASE` flag. + flags : int, default: 0 + Flags to pass through to the re module, e.g. `re.IGNORECASE`. + see `compilation-flags `_. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. + Cannot be set if `pat` is a compiled regex. + + Returns + ------- + extracted : object array + + Raises + ------ + ValueError + `pat` has no capture groups. + ValueError + `case` is set when `pat` is a compiled regular expression. + + Examples + -------- + Create a string array + + >>> value = xr.DataArray( + ... [ + ... [ + ... "a_Xy_0", + ... "ab_xY_10-bab_Xy_110-baab_Xy_1100", + ... "abc_Xy_01-cbc_Xy_2210", + ... ], + ... [ + ... "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + ... "", + ... "abcdef_Xy_101-fef_Xy_5543210", + ... ], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract matches + + >>> value.str.findall(r"(\w+)_Xy_(\d*)") + + array([[list([('a', '0')]), list([('bab', '110'), ('baab', '1100')]), + list([('abc', '01'), ('cbc', '2210')])], + [list([('abcd', ''), ('dcd', '33210'), ('dccd', '332210')]), + list([]), list([('abcdef', '101'), ('fef', '5543210')])]], + dtype=object) + Dimensions without coordinates: X, Y + + See Also + -------- + DataArray.str.extract + DataArray.str.extractall + re.compile + re.findall + pandas.Series.str.findall + """ + pat = self._re_compile(pat=pat, flags=flags, case=case) + + def func(x, ipat): + if ipat.groups == 0: + raise ValueError("No capture groups found in pattern.") + + return ipat.findall(x) + + return self._apply(func=func, func_args=(pat,), dtype=np.object_) + + def _partitioner( + self, + *, + func: Callable, + dim: Hashable, + sep: Optional[Union[str, bytes, Any]], + ) -> Any: + """ + Implements logic for `partition` and `rpartition`. + """ + sep = self._stringify(sep) + + if dim is None: + listfunc = lambda x, isep: list(func(x, isep)) + return self._apply(func=listfunc, func_args=(sep,), dtype=np.object_) + + # _apply breaks on an empty array in this case + if not self._obj.size: + return self._obj.copy().expand_dims({dim: 0}, axis=-1) + + arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) + + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + return self._apply( + func=arrfunc, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: 3}, + ).astype(self._obj.dtype.kind) + + def partition( + self, + dim: Optional[Hashable], + sep: Union[str, bytes, Any] = " ", + ) -> Any: + """ + Split the strings in the DataArray at the first occurrence of separator `sep`. + + This method splits the string at the first occurrence of `sep`, + and returns 3 elements containing the part before the separator, + the separator itself, and the part after the separator. + If the separator is not found, return 3 elements containing the string itself, + followed by two empty strings. + + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + dim : hashable or None + Name for the dimension to place the 3 elements in. + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " + String to split on. + If array-like, it is broadcast. + + Returns + ------- + partitioned : same type as values or object array + + See Also + -------- + DataArray.str.rpartition + str.partition + pandas.Series.str.partition + """ + return self._partitioner(func=self._obj.dtype.type.partition, dim=dim, sep=sep) + + def rpartition( + self, + dim: Optional[Hashable], + sep: Union[str, bytes, Any] = " ", + ) -> Any: + """ + Split the strings in the DataArray at the last occurrence of separator `sep`. + + This method splits the string at the last occurrence of `sep`, + and returns 3 elements containing the part before the separator, + the separator itself, and the part after the separator. + If the separator is not found, return 3 elements containing two empty strings, + followed by the string itself. + + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + dim : hashable or None + Name for the dimension to place the 3 elements in. + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " + String to split on. + If array-like, it is broadcast. + + Returns + ------- + rpartitioned : same type as values or object array + + See Also + -------- + DataArray.str.partition + str.rpartition + pandas.Series.str.rpartition + """ + return self._partitioner(func=self._obj.dtype.type.rpartition, dim=dim, sep=sep) + + def _splitter( + self, + *, + func: Callable, + pre: bool, + dim: Hashable, + sep: Optional[Union[str, bytes, Any]], + maxsplit: int, + ) -> Any: + """ + Implements logic for `split` and `rsplit`. + """ + if sep is not None: + sep = self._stringify(sep) + + if dim is None: + f_none = lambda x, isep: func(x, isep, maxsplit) + return self._apply(func=f_none, func_args=(sep,), dtype=np.object_) + + # _apply breaks on an empty array in this case + if not self._obj.size: + return self._obj.copy().expand_dims({dim: 0}, axis=-1) + + f_count = lambda x, isep: max(len(func(x, isep, maxsplit)), 1) + maxsplit = ( + self._apply(func=f_count, func_args=(sep,), dtype=np.int_).max().data.item() + - 1 + ) + + def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): + res = func(mystr, sep, maxsplit) + if len(res) < maxsplit + 1: + pad = [""] * (maxsplit + 1 - len(res)) + if pre: + res += pad + else: + res = pad + res + return np.array(res, dtype=dtype) + + # dtype MUST be object or strings can be truncated + # See: https://github.com/numpy/numpy/issues/8352 + return self._apply( + func=_dosplit, + func_args=(sep,), + dtype=np.object_, + output_core_dims=[[dim]], + output_sizes={dim: maxsplit}, + ).astype(self._obj.dtype.kind) + + def split( + self, + dim: Optional[Hashable], + sep: Union[str, bytes, Any] = None, + maxsplit: int = -1, + ) -> Any: + r""" + Split strings in a DataArray around the given separator/delimiter `sep`. + + Splits the string in the DataArray from the beginning, + at the specified delimiter string. + + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + dim : hashable or None + Name for the dimension to place the results in. + If `None`, place the results as list elements in an object DataArray. + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 + Limit number of splits in output, starting from the beginning. + If -1 (the default), return all splits. + + Returns + ------- + splitted : same type as values or object array + + Examples + -------- + Create a string DataArray + + >>> values = xr.DataArray( + ... [ + ... ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ... ], + ... dims=["X", "Y"], + ... ) + + Split once and put the results in a new dimension + + >>> values.str.split(dim="splitted", maxsplit=1) + + array([[['abc', 'def'], + ['spam', 'eggs\tswallow'], + ['red_blue', '']], + + [['test0', 'test1\ntest2\n\ntest3'], + ['', ''], + ['abra', 'ka\nda\tbra']]], dtype='>> values.str.split(dim="splitted") + + array([[['abc', 'def', '', ''], + ['spam', 'eggs', 'swallow', ''], + ['red_blue', '', '', '']], + + [['test0', 'test1', 'test2', 'test3'], + ['', '', '', ''], + ['abra', 'ka', 'da', 'bra']]], dtype='>> values.str.split(dim=None, maxsplit=1) + + array([[list(['abc', 'def']), list(['spam', 'eggs\tswallow']), + list(['red_blue'])], + [list(['test0', 'test1\ntest2\n\ntest3']), list([]), + list(['abra', 'ka\nda\tbra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split as many times as needed and put the results in a list + + >>> values.str.split(dim=None) + + array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), + list(['red_blue'])], + [list(['test0', 'test1', 'test2', 'test3']), list([]), + list(['abra', 'ka', 'da', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split only on spaces + + >>> values.str.split(dim="splitted", sep=" ") + + array([[['abc', 'def', ''], + ['spam\t\teggs\tswallow', '', ''], + ['red_blue', '', '']], + + [['test0\ntest1\ntest2\n\ntest3', '', ''], + ['', '', ''], + ['abra', '', 'ka\nda\tbra']]], dtype=' Any: + r""" + Split strings in a DataArray around the given separator/delimiter `sep`. + + Splits the string in the DataArray from the end, + at the specified delimiter string. + + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + dim : hashable or None + Name for the dimension to place the results in. + If `None`, place the results as list elements in an object DataArray + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 + Limit number of splits in output, starting from the end. + If -1 (the default), return all splits. + The final number of split values may be less than this if there are no + DataArray elements with that many values. + + Returns + ------- + rsplitted : same type as values or object array + + Examples + -------- + Create a string DataArray + + >>> values = xr.DataArray( + ... [ + ... ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ... ], + ... dims=["X", "Y"], + ... ) + + Split once and put the results in a new dimension + + >>> values.str.rsplit(dim="splitted", maxsplit=1) + + array([[['abc', 'def'], + ['spam\t\teggs', 'swallow'], + ['', 'red_blue']], + + [['test0\ntest1\ntest2', 'test3'], + ['', ''], + ['abra ka\nda', 'bra']]], dtype='>> values.str.rsplit(dim="splitted") + + array([[['', '', 'abc', 'def'], + ['', 'spam', 'eggs', 'swallow'], + ['', '', '', 'red_blue']], + + [['test0', 'test1', 'test2', 'test3'], + ['', '', '', ''], + ['abra', 'ka', 'da', 'bra']]], dtype='>> values.str.rsplit(dim=None, maxsplit=1) + + array([[list(['abc', 'def']), list(['spam\t\teggs', 'swallow']), + list(['red_blue'])], + [list(['test0\ntest1\ntest2', 'test3']), list([]), + list(['abra ka\nda', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split as many times as needed and put the results in a list + + >>> values.str.rsplit(dim=None) + + array([[list(['abc', 'def']), list(['spam', 'eggs', 'swallow']), + list(['red_blue'])], + [list(['test0', 'test1', 'test2', 'test3']), list([]), + list(['abra', 'ka', 'da', 'bra'])]], dtype=object) + Dimensions without coordinates: X, Y + + Split only on spaces + + >>> values.str.rsplit(dim="splitted", sep=" ") + + array([[['', 'abc', 'def'], + ['', '', 'spam\t\teggs\tswallow'], + ['', '', 'red_blue']], + + [['', '', 'test0\ntest1\ntest2\n\ntest3'], + ['', '', ''], + ['abra', '', 'ka\nda\tbra']]], dtype=' Any: + """ + Return DataArray of dummy/indicator variables. + + Each string in the DataArray is split at `sep`. + A new dimension is created with coordinates for each unique result, + and the corresponding element of that dimension is `True` if + that result is present and `False` if not. + + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + + Parameters + ---------- + dim : hashable + Name for the dimension to place the results in. + sep : str, default: "|". + String to split on. + If array-like, it is broadcast. + + Returns + ------- + dummies : array of bool + + Examples + -------- + Create a string array + + >>> values = xr.DataArray( + ... [ + ... ["a|ab~abc|abc", "ab", "a||abc|abcd"], + ... ["abcd|ab|a", "abc|ab~abc", "|a"], + ... ], + ... dims=["X", "Y"], + ... ) + + Extract dummy values + + >>> values.str.get_dummies(dim="dummies") + + array([[[ True, False, True, False, True], + [False, True, False, False, False], + [ True, False, True, True, False]], + + [[ True, True, False, True, False], + [False, False, True, False, True], + [ True, False, False, False, False]]]) + Coordinates: + * dummies (dummies) Any: """ Decode character string in the array using indicated encoding. Parameters ---------- encoding : str + The encoding to use. + Please see the Python documentation `codecs standard encoders `_ + section for a list of encodings handlers. errors : str, optional + The handler for encoding errors. + Please see the Python documentation `codecs error handlers `_ + for a list of error handlers. Returns ------- decoded : same type as values """ if encoding in _cpython_optimized_decoders: - f = lambda x: x.decode(encoding, errors) + func = lambda x: x.decode(encoding, errors) else: decoder = codecs.getdecoder(encoding) - f = lambda x: decoder(x, errors)[0] - return self._apply(f, dtype=np.str_) + func = lambda x: decoder(x, errors)[0] + return self._apply(func=func, dtype=np.str_) - def encode(self, encoding, errors="strict"): + def encode( + self, + encoding: str, + errors: str = "strict", + ) -> Any: """ Encode character string in the array using indicated encoding. Parameters ---------- encoding : str + The encoding to use. + Please see the Python documentation `codecs standard encoders `_ + section for a list of encodings handlers. errors : str, optional + The handler for encoding errors. + Please see the Python documentation `codecs error handlers `_ + for a list of error handlers. Returns ------- encoded : same type as values """ if encoding in _cpython_optimized_encoders: - f = lambda x: x.encode(encoding, errors) + func = lambda x: x.encode(encoding, errors) else: encoder = codecs.getencoder(encoding) - f = lambda x: encoder(x, errors)[0] - return self._apply(f, dtype=np.bytes_) + func = lambda x: encoder(x, errors)[0] + return self._apply(func=func, dtype=np.bytes_) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index debf3aad96a..a53ac094253 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -17,9 +17,9 @@ import numpy as np import pandas as pd -from . import dtypes, utils -from .indexing import get_indexer_nd -from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str +from . import dtypes +from .indexes import Index, PandasIndex, get_indexer_nd +from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import IndexVariable, Variable if TYPE_CHECKING: @@ -30,11 +30,11 @@ DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) -def _get_joiner(join): +def _get_joiner(join, index_cls): if join == "outer": - return functools.partial(functools.reduce, pd.Index.union) + return functools.partial(functools.reduce, index_cls.union) elif join == "inner": - return functools.partial(functools.reduce, pd.Index.intersection) + return functools.partial(functools.reduce, index_cls.intersection) elif join == "left": return operator.itemgetter(0) elif join == "right": @@ -47,25 +47,28 @@ def _get_joiner(join): # We rewrite all indexes and then use join='left' return operator.itemgetter(0) else: - raise ValueError("invalid value for join: %s" % join) + raise ValueError(f"invalid value for join: {join}") def _override_indexes(objects, all_indexes, exclude): for dim, dim_indexes in all_indexes.items(): if dim not in exclude: - lengths = {index.size for index in dim_indexes} + lengths = { + getattr(index, "size", index.to_pandas_index().size) + for index in dim_indexes + } if len(lengths) != 1: raise ValueError( - "Indexes along dimension %r don't have the same length." - " Cannot use join='override'." % dim + f"Indexes along dimension {dim!r} don't have the same length." + " Cannot use join='override'." ) objects = list(objects) for idx, obj in enumerate(objects[1:]): - new_indexes = {} - for dim in obj.indexes: - if dim not in exclude: - new_indexes[dim] = all_indexes[dim][0] + new_indexes = { + dim: all_indexes[dim][0] for dim in obj.xindexes if dim not in exclude + } + objects[idx + 1] = obj._overwrite_indexes(new_indexes) return objects @@ -135,8 +138,6 @@ def align( Examples -------- - - >>> import xarray as xr >>> x = xr.DataArray( ... [[25, 35], [10, 24]], ... dims=("lat", "lon"), @@ -285,7 +286,7 @@ def align( if dim not in exclude: all_coords[dim].append(obj.coords[dim]) try: - index = obj.indexes[dim] + index = obj.xindexes[dim] except KeyError: unlabeled_dim_sizes[dim].add(obj.sizes[dim]) else: @@ -299,11 +300,12 @@ def align( # - It ensures it's possible to do operations that don't require alignment # on indexes with duplicate values (which cannot be reindexed with # pandas). This is useful, e.g., for overwriting such duplicate indexes. - joiner = _get_joiner(join) joined_indexes = {} for dim, matching_indexes in all_indexes.items(): if dim in indexes: - index = utils.safe_cast_to_index(indexes[dim]) + index, _ = PandasIndex.from_pandas_index( + safe_cast_to_index(indexes[dim]), dim + ) if ( any(not index.equals(other) for other in matching_indexes) or dim in unlabeled_dim_sizes @@ -319,37 +321,45 @@ def align( ): if join == "exact": raise ValueError(f"indexes along dimension {dim!r} are not equal") + joiner = _get_joiner(join, type(matching_indexes[0])) index = joiner(matching_indexes) # make sure str coords are not cast to object - index = maybe_coerce_to_str(index, all_coords[dim]) + index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim]) joined_indexes[dim] = index else: index = all_coords[dim][0] if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] - labeled_size = index.size + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + if isinstance(index, PandasIndex): + labeled_size = index.to_pandas_index().size + else: + labeled_size = index.size if len(unlabeled_sizes | {labeled_size}) > 1: raise ValueError( - "arguments without labels along dimension %r cannot be " - "aligned because they have different dimension size(s) %r " - "than the size of the aligned dimension labels: %r" - % (dim, unlabeled_sizes, labeled_size) + f"arguments without labels along dimension {dim!r} cannot be " + f"aligned because they have different dimension size(s) {unlabeled_sizes!r} " + f"than the size of the aligned dimension labels: {labeled_size!r}" ) - for dim in unlabeled_dim_sizes: - if dim not in all_indexes: - sizes = unlabeled_dim_sizes[dim] - if len(sizes) > 1: - raise ValueError( - "arguments without labels along dimension %r cannot be " - "aligned because they have different dimension sizes: %r" - % (dim, sizes) - ) + for dim, sizes in unlabeled_dim_sizes.items(): + if dim not in all_indexes and len(sizes) > 1: + raise ValueError( + f"arguments without labels along dimension {dim!r} cannot be " + f"aligned because they have different dimension sizes: {sizes!r}" + ) result = [] for obj in objects: - valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims} + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + valid_indexers = {} + for k, index in joined_indexes.items(): + if k in obj.dims: + if isinstance(index, Index): + valid_indexers[k] = index.to_pandas_index() + else: + valid_indexers[k] = index if not valid_indexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) @@ -470,7 +480,11 @@ def reindex_like_indexers( ValueError If any dimensions without labels have different sizes. """ - indexers = {k: v for k, v in other.indexes.items() if k in target.dims} + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + # this doesn't support yet indexes other than pd.Index + indexers = { + k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims + } for dim in other.dims: if dim not in indexers and dim in target.dims: @@ -479,8 +493,7 @@ def reindex_like_indexers( if other_size != target_size: raise ValueError( "different size for unlabeled " - "dimension on argument %r: %r vs %r" - % (dim, other_size, target_size) + f"dimension on argument {dim!r}: {other_size!r} vs {target_size!r}" ) return indexers @@ -488,14 +501,14 @@ def reindex_like_indexers( def reindex_variables( variables: Mapping[Any, Variable], sizes: Mapping[Any, int], - indexes: Mapping[Any, pd.Index], + indexes: Mapping[Any, Index], indexers: Mapping, method: Optional[str] = None, tolerance: Any = None, copy: bool = True, fill_value: Optional[Any] = dtypes.NA, sparse: bool = False, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Conform a dictionary of aligned variables onto a new set of variables, filling in missing values with NaN. @@ -532,7 +545,7 @@ def reindex_variables( the input. In either case, new xarray objects are always returned. fill_value : scalar, optional Value to use for newly missing values - sparse: bool, optional + sparse : bool, optional Use an sparse-array Returns @@ -560,15 +573,17 @@ def reindex_variables( "from that to be indexed along {:s}".format(str(indexer.dims), dim) ) - target = new_indexes[dim] = utils.safe_cast_to_index(indexers[dim]) + target = safe_cast_to_index(indexers[dim]) + new_indexes[dim] = PandasIndex(target, dim) if dim in indexes: - index = indexes[dim] + # TODO (benbovy - flexible indexes): support other indexes than pd.Index? + index = indexes[dim].to_pandas_index() if not index.is_unique: raise ValueError( - "cannot reindex or align along dimension %r because the " - "index has duplicate values" % dim + f"cannot reindex or align along dimension {dim!r} because the " + "index has duplicate values" ) int_indexer = get_indexer_nd(index, target, method, tolerance) @@ -595,9 +610,9 @@ def reindex_variables( new_size = indexers[dim].size if existing_size != new_size: raise ValueError( - "cannot reindex or align along dimension %r without an " - "index because its size %r is different from the size of " - "the new index %r" % (dim, existing_size, new_size) + f"cannot reindex or align along dimension {dim!r} without an " + f"index because its size {existing_size!r} is different from the size of " + f"the new index {new_size!r}" ) for name, var in variables.items(): @@ -704,7 +719,6 @@ def broadcast(*args, exclude=None): Examples -------- - Broadcast two data arrays against one another to fill out their dimensions: >>> a = xr.DataArray([1, 2, 3], dims="x") @@ -749,8 +763,6 @@ def broadcast(*args, exclude=None): args = align(*args, join="outer", copy=False, exclude=exclude) dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - result = [] - for arg in args: - result.append(_broadcast_helper(arg, exclude, dims_map, common_coords)) + result = [_broadcast_helper(arg, exclude, dims_map, common_coords) for arg in args] return tuple(result) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 8eba0fe7919..27ec5ab8dd9 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -3,9 +3,18 @@ import numpy as np +# _typed_ops.py is a generated file +from ._typed_ops import ( + DataArrayGroupByOpsMixin, + DataArrayOpsMixin, + DatasetGroupByOpsMixin, + DatasetOpsMixin, + VariableOpsMixin, +) +from .common import ImplementsArrayReduce, ImplementsDatasetReduce +from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs from .pycompat import dask_array_type -from .utils import not_implemented class SupportsArithmetic: @@ -80,26 +89,61 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): keep_attrs=_get_keep_attrs(default=True), ) - # this has no runtime function - these are listed so IDEs know these - # methods are defined and don't warn on these operations - __lt__ = ( - __le__ - ) = ( - __ge__ - ) = ( - __gt__ - ) = ( - __add__ - ) = ( - __sub__ - ) = ( - __mul__ - ) = ( - __truediv__ - ) = ( - __floordiv__ - ) = ( - __mod__ - ) = ( - __pow__ - ) = __and__ = __xor__ = __or__ = __div__ = __eq__ = __ne__ = not_implemented + +class VariableArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + IncludeCumMethods, + IncludeNumpySameMethods, + SupportsArithmetic, + VariableOpsMixin, +): + __slots__ = () + # prioritize our operations over those of numpy.ndarray (priority=0) + __array_priority__ = 50 + + +class DatasetArithmetic( + ImplementsDatasetReduce, + IncludeReduceMethods, + IncludeCumMethods, + SupportsArithmetic, + DatasetOpsMixin, +): + __slots__ = () + __array_priority__ = 50 + + +class DataArrayArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + IncludeCumMethods, + IncludeNumpySameMethods, + SupportsArithmetic, + DataArrayOpsMixin, +): + __slots__ = () + # priority must be higher than Variable to properly work with binary ufuncs + __array_priority__ = 60 + + +class DataArrayGroupbyArithmetic( + ImplementsArrayReduce, + IncludeReduceMethods, + SupportsArithmetic, + DataArrayGroupByOpsMixin, +): + __slots__ = () + + +class DatasetGroupbyArithmetic( + ImplementsDatasetReduce, + IncludeReduceMethods, + SupportsArithmetic, + DatasetGroupByOpsMixin, +): + __slots__ = () + + +class CoarsenArithmetic(IncludeReduceMethods): + __slots__ = () diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 86ed1870302..7e1565e50de 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,4 +1,5 @@ import itertools +import warnings from collections import Counter import pandas as pd @@ -8,11 +9,11 @@ from .dataarray import DataArray from .dataset import Dataset from .merge import merge +from .utils import iterate_nested def _infer_concat_order_from_positions(datasets): - combined_ids = dict(_infer_tile_ids_from_nested_list(datasets, ())) - return combined_ids + return dict(_infer_tile_ids_from_nested_list(datasets, ())) def _infer_tile_ids_from_nested_list(entry, current_pos): @@ -44,6 +45,18 @@ def _infer_tile_ids_from_nested_list(entry, current_pos): yield current_pos, entry +def _ensure_same_types(series, dim): + + if series.dtype == object: + types = set(series.map(type)) + if len(types) > 1: + types = ", ".join(t.__name__ for t in types) + raise TypeError( + f"Cannot combine along dimension '{dim}' with mixed types." + f" Found: {types}." + ) + + def _infer_concat_order_from_coords(datasets): concat_dims = [] @@ -57,13 +70,16 @@ def _infer_concat_order_from_coords(datasets): if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds.indexes.get(dim) for ds in datasets] + indexes = [ds.xindexes.get(dim) for ds in datasets] if any(index is None for index in indexes): raise ValueError( "Every dimension needs a coordinate for " "inferring concatenation order" ) + # TODO (benbovy, flexible indexes): support flexible indexes? + indexes = [index.to_pandas_index() for index in indexes] + # If dimension coordinate values are same on every dataset then # should be leaving this dimension alone (it's just a "bystander") if not all(index.equals(indexes[0]) for index in indexes[1:]): @@ -88,11 +104,15 @@ def _infer_concat_order_from_coords(datasets): raise ValueError("Cannot handle size zero dimensions") first_items = pd.Index([index[0] for index in indexes]) + series = first_items.to_series() + + # ensure series does not contain mixed types, e.g. cftime calendars + _ensure_same_types(series, dim) + # Sort datasets along dim # We want rank but with identical elements given identical # position indices - they should be concatenated along another # dimension, not along this one - series = first_items.to_series() rank = series.rank( method="dense", ascending=ascending, numeric_only=False ) @@ -124,7 +144,7 @@ def _check_dimension_depth_tile_ids(combined_tile_ids): nesting_depths = [len(tile_id) for tile_id in tile_ids] if not nesting_depths: nesting_depths = [0] - if not set(nesting_depths) == {nesting_depths[0]}: + if set(nesting_depths) != {nesting_depths[0]}: raise ValueError( "The supplied objects do not form a hypercube because" " sub-lists do not have consistent depths" @@ -356,7 +376,7 @@ def combine_nested( To concatenate along multiple dimensions the datasets must be passed as a nested list-of-lists, with a depth equal to the length of ``concat_dims``. - ``manual_combine`` will concatenate along the top-level list first. + ``combine_nested`` will concatenate along the top-level list first. Useful for combining datasets from a set of nested directories, or for collecting the output of a simulation parallelized along multiple @@ -412,17 +432,23 @@ def combine_nested( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ - default: "drop" - String indicating how to combine attrs of the objects being merged: + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "drop" + A callable or a string indicating how to combine attrs of the objects being + merged: - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to the result. + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + Returns ------- combined : xarray.Dataset @@ -478,7 +504,7 @@ def combine_nested( temperature (x, y) float64 1.764 0.4002 -0.1032 ... 0.04576 -0.1872 precipitation (x, y) float64 1.868 -0.9773 0.761 ... -0.7422 0.1549 0.3782 - ``manual_combine`` can also be used to explicitly merge datasets with + ``combine_nested`` can also be used to explicitly merge datasets with different variables. For example if we have 4 datasets, which are divided along two times, and contain two different variables, we can pass ``None`` to ``concat_dim`` to specify the dimension of the nested list over which @@ -519,10 +545,19 @@ def combine_nested( concat merge """ + mixed_datasets_and_arrays = any( + isinstance(obj, Dataset) for obj in iterate_nested(datasets) + ) and any( + isinstance(obj, DataArray) and obj.name is None + for obj in iterate_nested(datasets) + ) + if mixed_datasets_and_arrays: + raise ValueError("Can't combine datasets with unnamed arrays.") + if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: concat_dim = [concat_dim] - # The IDs argument tells _manual_combine that datasets aren't yet sorted + # The IDs argument tells _nested_combine that datasets aren't yet sorted return _nested_combine( datasets, concat_dims=concat_dim, @@ -540,18 +575,79 @@ def vars_as_keys(ds): return tuple(sorted(ds)) -def combine_by_coords( +def _combine_single_variable_hypercube( datasets, + fill_value=dtypes.NA, + data_vars="all", + coords="different", + compat="no_conflicts", + join="outer", + combine_attrs="no_conflicts", +): + """ + Attempt to combine a list of Datasets into a hypercube using their + coordinates. + + All provided Datasets must belong to a single variable, ie. must be + assigned the same variable name. This precondition is not checked by this + function, so the caller is assumed to know what it's doing. + + This function is NOT part of the public API. + """ + if len(datasets) == 0: + raise ValueError( + "At least one Dataset is required to resolve variable names " + "for combined hypercube." + ) + + combined_ids, concat_dims = _infer_concat_order_from_coords(list(datasets)) + + if fill_value is None: + # check that datasets form complete hypercube + _check_shape_tile_ids(combined_ids) + else: + # check only that all datasets have same dimension depth for these + # vars + _check_dimension_depth_tile_ids(combined_ids) + + # Concatenate along all of concat_dims one by one to create single ds + concatenated = _combine_nd( + combined_ids, + concat_dims=concat_dims, + data_vars=data_vars, + coords=coords, + compat=compat, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs, + ) + + # Check the overall coordinates are monotonically increasing + for dim in concat_dims: + indexes = concatenated.indexes.get(dim) + if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): + raise ValueError( + "Resulting object does not have monotonic" + " global indexes along dimension {}".format(dim) + ) + + return concatenated + + +# TODO remove empty list default param after version 0.21, see PR4696 +def combine_by_coords( + data_objects=[], compat="no_conflicts", data_vars="all", coords="different", fill_value=dtypes.NA, join="outer", combine_attrs="no_conflicts", + datasets=None, ): """ - Attempt to auto-magically combine the given datasets into one by using - dimension coordinates. + Attempt to auto-magically combine the given datasets (or data arrays) + into one by using dimension coordinates. This method attempts to combine a group of datasets along any number of dimensions into a single entity by inspecting coords and metadata and using @@ -565,7 +661,7 @@ def combine_by_coords( Aligns coordinates, but different variables on datasets can cause it to fail under some scenarios. In complex cases, you may need to clean up - your data and use concat/merge explicitly (also see `manual_combine`). + your data and use concat/merge explicitly (also see `combine_nested`). Works well if, for example, you have N years of data and M data variables, and each combination of a distinct time period and set of data variables is @@ -575,8 +671,9 @@ def combine_by_coords( Parameters ---------- - datasets : sequence of xarray.Dataset - Dataset objects to combine. + data_objects : sequence of xarray.Dataset or sequence of xarray.DataArray + Data objects to combine. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional String indicating how to compare variables of the same name for potential conflicts: @@ -613,8 +710,7 @@ def combine_by_coords( refer to its values. If None, raises a ValueError if the passed Datasets do not create a complete hypercube. join : {"outer", "inner", "left", "right", "exact"}, optional - String indicating how to combine differing indexes - (excluding concat_dim) in objects + String indicating how to combine differing indexes in objects - "outer": use the union of object indexes - "inner": use the intersection of object indexes @@ -625,17 +721,23 @@ def combine_by_coords( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ - default: "drop" - String indicating how to combine attrs of the objects being merged: + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "drop" + A callable or a string indicating how to combine attrs of the objects being + merged: - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to the result. + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + Returns ------- combined : xarray.Dataset @@ -653,9 +755,6 @@ def combine_by_coords( they are concatenated based on the values in their dimension coordinates, not on their position in the list passed to `combine_by_coords`. - >>> import numpy as np - >>> import xarray as xr - >>> x1 = xr.Dataset( ... { ... "temperature": (("y", "x"), 20 * np.random.rand(6).reshape(2, 3)), @@ -680,7 +779,7 @@ def combine_by_coords( >>> x1 - Dimensions: (x: 3, y: 2) + Dimensions: (y: 2, x: 3) Coordinates: * y (y) int64 0 1 * x (x) int64 10 20 30 @@ -690,7 +789,7 @@ def combine_by_coords( >>> x2 - Dimensions: (x: 3, y: 2) + Dimensions: (y: 2, x: 3) Coordinates: * y (y) int64 2 3 * x (x) int64 10 20 30 @@ -700,7 +799,7 @@ def combine_by_coords( >>> x3 - Dimensions: (x: 3, y: 2) + Dimensions: (y: 2, x: 3) Coordinates: * y (y) int64 2 3 * x (x) int64 40 50 60 @@ -710,7 +809,7 @@ def combine_by_coords( >>> xr.combine_by_coords([x2, x1]) - Dimensions: (x: 3, y: 4) + Dimensions: (y: 4, x: 3) Coordinates: * y (y) int64 0 1 2 3 * x (x) int64 10 20 30 @@ -720,76 +819,91 @@ def combine_by_coords( >>> xr.combine_by_coords([x3, x1]) - Dimensions: (x: 6, y: 4) + Dimensions: (y: 4, x: 6) Coordinates: - * x (x) int64 10 20 30 40 50 60 * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 40 50 60 Data variables: temperature (y, x) float64 10.98 14.3 12.06 nan ... nan 18.89 10.44 8.293 precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x3, x1], join="override") - Dimensions: (x: 3, y: 4) + Dimensions: (y: 2, x: 6) Coordinates: - * x (x) int64 10 20 30 - * y (y) int64 0 1 2 3 + * y (y) int64 0 1 + * x (x) int64 10 20 30 40 50 60 Data variables: - temperature (y, x) float64 10.98 14.3 12.06 10.9 ... 18.89 10.44 8.293 + temperature (y, x) float64 10.98 14.3 12.06 2.365 ... 18.89 10.44 8.293 precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x1, x2, x3]) - Dimensions: (x: 6, y: 4) + Dimensions: (y: 4, x: 6) Coordinates: - * x (x) int64 10 20 30 40 50 60 * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 40 50 60 Data variables: temperature (y, x) float64 10.98 14.3 12.06 nan ... 18.89 10.44 8.293 precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 """ - # Group by data vars - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) - - # Perform the multidimensional combine on each group of data variables - # before merging back together - concatenated_grouped_by_data_vars = [] - for vars, datasets_with_same_vars in grouped_by_vars: - combined_ids, concat_dims = _infer_concat_order_from_coords( - list(datasets_with_same_vars) + # TODO remove after version 0.21, see PR4696 + if datasets is not None: + warnings.warn( + "The datasets argument has been renamed to `data_objects`." + " From 0.21 on passing a value for datasets will raise an error." ) + data_objects = datasets - if fill_value is None: - # check that datasets form complete hypercube - _check_shape_tile_ids(combined_ids) - else: - # check only that all datasets have same dimension depth for these - # vars - _check_dimension_depth_tile_ids(combined_ids) + if not data_objects: + return Dataset() - # Concatenate along all of concat_dims one by one to create single ds - concatenated = _combine_nd( - combined_ids, - concat_dims=concat_dims, + mixed_arrays_and_datasets = any( + isinstance(data_object, DataArray) and data_object.name is None + for data_object in data_objects + ) and any(isinstance(data_object, Dataset) for data_object in data_objects) + if mixed_arrays_and_datasets: + raise ValueError("Can't automatically combine datasets with unnamed arrays.") + + all_unnamed_data_arrays = all( + isinstance(data_object, DataArray) and data_object.name is None + for data_object in data_objects + ) + if all_unnamed_data_arrays: + unnamed_arrays = data_objects + temp_datasets = [data_array._to_temp_dataset() for data_array in unnamed_arrays] + + combined_temp_dataset = _combine_single_variable_hypercube( + temp_datasets, + fill_value=fill_value, data_vars=data_vars, coords=coords, compat=compat, - fill_value=fill_value, join=join, combine_attrs=combine_attrs, ) + return DataArray()._from_temp_dataset(combined_temp_dataset) - # Check the overall coordinates are monotonically increasing - for dim in concat_dims: - indexes = concatenated.indexes.get(dim) - if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing): - raise ValueError( - "Resulting object does not have monotonic" - " global indexes along dimension {}".format(dim) - ) - concatenated_grouped_by_data_vars.append(concatenated) + else: + # Group by data vars + sorted_datasets = sorted(data_objects, key=vars_as_keys) + grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) + + # Perform the multidimensional combine on each group of data variables + # before merging back together + concatenated_grouped_by_data_vars = [] + for vars, datasets_with_same_vars in grouped_by_vars: + concatenated = _combine_single_variable_hypercube( + list(datasets_with_same_vars), + fill_value=fill_value, + data_vars=data_vars, + coords=coords, + compat=compat, + join=join, + combine_attrs=combine_attrs, + ) + concatenated_grouped_by_data_vars.append(concatenated) return merge( concatenated_grouped_by_data_vars, diff --git a/xarray/core/common.py b/xarray/core/common.py index c5836c68759..ab822f576d3 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -16,13 +16,13 @@ Tuple, TypeVar, Union, + overload, ) import numpy as np import pandas as pd from . import dtypes, duck_array_ops, formatting, formatting_html, ops -from .arithmetic import SupportsArithmetic from .npcompat import DTypeLike from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array @@ -35,6 +35,8 @@ if TYPE_CHECKING: from .dataarray import DataArray + from .dataset import Dataset + from .variable import Variable from .weighted import Weighted T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") @@ -55,7 +57,7 @@ def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs): else: - def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore + def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore[misc] return self.reduce(func, dim, axis, **kwargs) return wrapped_func @@ -94,7 +96,7 @@ def wrapped_func(self, dim=None, skipna=None, **kwargs): else: - def wrapped_func(self, dim=None, **kwargs): # type: ignore + def wrapped_func(self, dim=None, **kwargs): # type: ignore[misc] return self.reduce(func, dim, numeric_only=numeric_only, **kwargs) return wrapped_func @@ -118,7 +120,7 @@ def wrapped_func(self, dim=None, **kwargs): # type: ignore ).strip() -class AbstractArray(ImplementsArrayReduce): +class AbstractArray: """Shared base class for DataArray and Variable.""" __slots__ = () @@ -187,7 +189,7 @@ def sizes(self: Any) -> Mapping[Hashable, int]: Immutable. - See also + See Also -------- Dataset.sizes """ @@ -199,7 +201,7 @@ class AttrAccessMixin: __slots__ = () - def __init_subclass__(cls): + def __init_subclass__(cls, **kwargs): """Verify that all subclasses explicitly define ``__slots__``. If they don't, raise error in the core xarray module and a FutureWarning in third-party extensions. @@ -207,14 +209,15 @@ def __init_subclass__(cls): if not hasattr(object.__new__(cls), "__dict__"): pass elif cls.__module__.startswith("xarray."): - raise AttributeError("%s must explicitly define __slots__" % cls.__name__) + raise AttributeError(f"{cls.__name__} must explicitly define __slots__") else: cls.__setattr__ = cls._setattr_dict warnings.warn( - "xarray subclass %s should explicitly define __slots__" % cls.__name__, + f"xarray subclass {cls.__name__} should explicitly define __slots__", FutureWarning, stacklevel=2, ) + super().__init_subclass__(**kwargs) @property def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: @@ -248,10 +251,9 @@ def _setattr_dict(self, name: str, value: Any) -> None: if name in self.__dict__: # Custom, non-slotted attr, or improperly assigned variable? warnings.warn( - "Setting attribute %r on a %r object. Explicitly define __slots__ " + f"Setting attribute {name!r} on a {type(self).__name__!r} object. Explicitly define __slots__ " "to suppress this warning for legitimate custom attributes and " - "raise an error when attempting variables assignments." - % (name, type(self).__name__), + "raise an error when attempting variables assignments.", FutureWarning, stacklevel=2, ) @@ -271,9 +273,8 @@ def __setattr__(self, name: str, value: Any) -> None: ): raise raise AttributeError( - "cannot set attribute %r on a %r object. Use __setitem__ style" + f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." - % (name, type(self).__name__) ) from e def __dir__(self) -> List[str]: @@ -335,15 +336,13 @@ def get_squeeze_dims( return dim -class DataWithCoords(SupportsArithmetic, AttrAccessMixin): +class DataWithCoords(AttrAccessMixin): """Shared base class for Dataset and DataArray.""" _close: Optional[Callable[[], None]] __slots__ = ("_close",) - _rolling_exp_cls = RollingExp - def squeeze( self, dim: Union[Hashable, Iterable[Hashable], None] = None, @@ -377,13 +376,35 @@ def squeeze( dims = get_squeeze_dims(self, dim, axis) return self.isel(drop=drop, **{d: 0 for d in dims}) + def clip(self, min=None, max=None, *, keep_attrs: bool = None): + """ + Return an array whose values are limited to ``[min, max]``. + At least one of max or min must be given. + + Refer to `numpy.clip` for full documentation. + + See Also + -------- + numpy.clip : equivalent function + """ + from .computation import apply_ufunc + + if keep_attrs is None: + # When this was a unary func, the default was True, so retaining the + # default. + keep_attrs = _get_keep_attrs(default=True) + + return apply_ufunc( + np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + ) + def get_index(self, key: Hashable) -> pd.Index: """Get an index for a dimension, with fall-back to a default RangeIndex""" if key not in self.dims: raise KeyError(key) try: - return self.indexes[key] + return self.xindexes[key].to_pandas_index() except KeyError: return pd.Index(range(self.sizes[key]), name=key) @@ -409,7 +430,6 @@ def assign_coords(self, coords=None, **coords_kwargs): defined and attached to an existing dimension using a tuple with the first element the dimension name and the second element the values for this new coordinate. - **coords_kwargs : optional The keyword arguments form of ``coords``. One of ``coords`` or ``coords_kwargs`` must be provided. @@ -470,7 +490,7 @@ def assign_coords(self, coords=None, **coords_kwargs): is possible, but you cannot reference other variables created within the same ``assign_coords`` call. - See also + See Also -------- Dataset.assign Dataset.swap_dims @@ -488,9 +508,9 @@ def assign_attrs(self, *args, **kwargs): Parameters ---------- - args + *args positional arguments passed into ``attrs.update``. - kwargs + **kwargs keyword arguments passed into ``attrs.update``. Returns @@ -498,7 +518,7 @@ def assign_attrs(self, *args, **kwargs): assigned : same type as caller A new object with the new attrs in addition to the existing data. - See also + See Also -------- Dataset.assign """ @@ -525,9 +545,9 @@ def pipe( Alternatively a ``(callable, data_keyword)`` tuple where ``data_keyword`` is a string indicating the keyword of ``callable`` that expects the xarray object. - args + *args positional arguments passed into ``func``. - kwargs + **kwargs a dictionary of keyword arguments passed into ``func``. Returns @@ -537,7 +557,6 @@ def pipe( Notes ----- - Use ``.pipe`` when chaining together functions that expect xarray or pandas objects, e.g., instead of writing @@ -561,9 +580,6 @@ def pipe( Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> x = xr.Dataset( ... { ... "temperature_c": ( @@ -635,7 +651,7 @@ def pipe( func, target = func if target in kwargs: raise ValueError( - "%s is both the pipe target and a keyword argument" % target + f"{target} is both the pipe target and a keyword argument" ) kwargs[target] = self return func(*args, **kwargs) @@ -805,7 +821,6 @@ def rolling( dim: Mapping[Hashable, int] = None, min_periods: int = None, center: Union[bool, Mapping[Hashable, bool]] = False, - keep_attrs: bool = None, **window_kwargs: int, ): """ @@ -813,7 +828,7 @@ def rolling( Parameters ---------- - dim: dict, optional + dim : dict, optional Mapping from the dimension name to create the rolling iterator along (e.g. `time`) to its moving window size. min_periods : int, default: None @@ -873,9 +888,7 @@ def rolling( """ dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") - return self._rolling_cls( - self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs - ) + return self._rolling_cls(self, dim, min_periods=min_periods, center=center) def rolling_exp( self, @@ -906,9 +919,17 @@ def rolling_exp( -------- core.rolling_exp.RollingExp """ + + if "keep_attrs" in window_kwargs: + warnings.warn( + "Passing ``keep_attrs`` to ``rolling_exp`` has no effect. Pass" + " ``keep_attrs`` directly to the applied function, e.g." + " ``rolling_exp(...).mean(keep_attrs=False)``." + ) + window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp") - return self._rolling_exp_cls(self, window, window_type) + return RollingExp(self, window, window_type) def coarsen( self, @@ -916,7 +937,6 @@ def coarsen( boundary: str = "exact", side: Union[str, Mapping[Hashable, str]] = "left", coord_func: str = "mean", - keep_attrs: bool = None, **window_kwargs: int, ): """ @@ -934,10 +954,6 @@ def coarsen( coord_func : str or mapping of hashable to str, default: "mean" function (name) that is applied to the coordinates, or a mapping from coordinate name to function (name). - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. Returns ------- @@ -981,8 +997,6 @@ def coarsen( core.rolling.DataArrayCoarsen core.rolling.DatasetCoarsen """ - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") return self._coarsen_cls( @@ -991,7 +1005,6 @@ def coarsen( boundary=boundary, side=side, coord_func=coord_func, - keep_attrs=keep_attrs, ) def resample( @@ -1031,10 +1044,6 @@ def resample( loffset : timedelta or str, optional Offset used to adjust the resampled time labels. Some pandas date offset strings are supported. - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. restore_coord_dims : bool, optional If True, also restore the dimension order of multi-dimensional coordinates. @@ -1101,7 +1110,6 @@ def resample( References ---------- - .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ # TODO support non-string indexer after removing the old API. @@ -1110,8 +1118,12 @@ def resample( from .dataarray import DataArray from .resample import RESAMPLE_DIM - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) + if keep_attrs is not None: + warnings.warn( + "Passing ``keep_attrs`` to ``resample`` has no effect and will raise an" + " error in xarray 0.20. Pass ``keep_attrs`` directly to the applied" + " function, e.g. ``resample(...).mean(keep_attrs=True)``." + ) # note: the second argument (now 'skipna') use to be 'dim' if ( @@ -1141,7 +1153,8 @@ def resample( category=FutureWarning, ) - if isinstance(self.indexes[dim_name], CFTimeIndex): + # TODO (benbovy - flexible indexes): update when CFTimeIndex is an xarray Index subclass + if isinstance(self.xindexes[dim_name].to_pandas_index(), CFTimeIndex): from .resample_cftime import CFTimeGrouper grouper = CFTimeGrouper(freq, closed, label, base, loffset) @@ -1189,8 +1202,6 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): Examples -------- - - >>> import numpy as np >>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) >>> a @@ -1235,7 +1246,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): [15., nan, nan, nan]]) Dimensions without coordinates: x, y - See also + See Also -------- numpy.where : corresponding numpy function where : equivalent function @@ -1253,8 +1264,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): if not isinstance(cond, (Dataset, DataArray)): raise TypeError( - "cond argument is %r but must be a %r or %r" - % (cond, Dataset, DataArray) + f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}" ) # align so we can use integer indexing @@ -1386,14 +1396,13 @@ def isin(self, test_elements): Examples -------- - >>> array = xr.DataArray([1, 2, 3], dims="x") >>> array.isin([1, 3]) array([ True, False, True]) Dimensions without coordinates: x - See also + See Also -------- numpy.isin """ @@ -1452,7 +1461,6 @@ def astype( * 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed. * 'unsafe' means any data conversions may be done. - subok : bool, optional If True, then sub-classes will be passed-through, otherwise the returned array will be forced to be a base-class array. @@ -1477,7 +1485,7 @@ def astype( Make sure to only supply these arguments if the underlying array class supports them. - See also + See Also -------- numpy.ndarray.astype dask.array.Array.astype @@ -1508,7 +1516,26 @@ def __getitem__(self, value): raise NotImplementedError() -def full_like(other, fill_value, dtype: DTypeLike = None): +@overload +def full_like( + other: "Dataset", + fill_value, + dtype: Union[DTypeLike, Mapping[Hashable, DTypeLike]] = None, +) -> "Dataset": + ... + + +@overload +def full_like(other: "DataArray", fill_value, dtype: DTypeLike = None) -> "DataArray": + ... + + +@overload +def full_like(other: "Variable", fill_value, dtype: DTypeLike = None) -> "Variable": + ... + + +def full_like(other, fill_value, dtype=None): """Return a new object with the same shape and type as a given object. Parameters @@ -1533,9 +1560,6 @@ def full_like(other, fill_value, dtype: DTypeLike = None): Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> x = xr.DataArray( ... np.arange(6).reshape(2, 3), ... dims=["lat", "lon"], @@ -1609,9 +1633,8 @@ def full_like(other, fill_value, dtype: DTypeLike = None): a (x) bool True True True b (x) float64 2.0 2.0 2.0 - See also + See Also -------- - zeros_like ones_like @@ -1627,15 +1650,22 @@ def full_like(other, fill_value, dtype: DTypeLike = None): f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead." ) + if not isinstance(other, Dataset) and isinstance(dtype, Mapping): + raise ValueError( + "'dtype' cannot be dict-like when passing a DataArray or Variable" + ) + if isinstance(other, Dataset): if not isinstance(fill_value, dict): fill_value = {k: fill_value for k in other.data_vars.keys()} - if not isinstance(dtype, dict): - dtype = {k: dtype for k in other.data_vars.keys()} + if not isinstance(dtype, Mapping): + dtype_ = {k: dtype for k in other.data_vars.keys()} + else: + dtype_ = dtype data_vars = { - k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype.get(k, None)) + k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None)) for k, v in other.data_vars.items() } return Dataset(data_vars, coords=other.coords, attrs=other.attrs) @@ -1692,9 +1722,6 @@ def zeros_like(other, dtype: DTypeLike = None): Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> x = xr.DataArray( ... np.arange(6).reshape(2, 3), ... dims=["lat", "lon"], @@ -1724,9 +1751,8 @@ def zeros_like(other, dtype: DTypeLike = None): * lat (lat) int64 1 2 * lon (lon) int64 0 1 2 - See also + See Also -------- - ones_like full_like @@ -1752,9 +1778,6 @@ def ones_like(other, dtype: DTypeLike = None): Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> x = xr.DataArray( ... np.arange(6).reshape(2, 3), ... dims=["lat", "lon"], @@ -1776,9 +1799,8 @@ def ones_like(other, dtype: DTypeLike = None): * lat (lat) int64 1 2 * lon (lon) int64 0 1 2 - See also + See Also -------- - zeros_like full_like diff --git a/xarray/core/computation.py b/xarray/core/computation.py index e0d9ff4b218..cd9e22d90db 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1,12 +1,13 @@ """ Functions for applying functions that act on arrays to xarray's labeled data. """ +from __future__ import annotations + import functools import itertools import operator import warnings from collections import Counter -from distutils.version import LooseVersion from typing import ( TYPE_CHECKING, AbstractSet, @@ -20,6 +21,7 @@ Optional, Sequence, Tuple, + TypeVar, Union, ) @@ -27,29 +29,37 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align -from .merge import merge_coordinates_without_align -from .options import OPTIONS -from .pycompat import is_duck_dask_array +from .merge import merge_attrs, merge_coordinates_without_align +from .options import OPTIONS, _get_keep_attrs +from .pycompat import dask_version, is_duck_dask_array from .utils import is_dict_like from .variable import Variable if TYPE_CHECKING: from .coordinates import Coordinates # noqa + from .dataarray import DataArray from .dataset import Dataset + T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) + _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) def _first_of_type(args, kind): - """ Return either first object of type 'kind' or raise if not found. """ + """Return either first object of type 'kind' or raise if not found.""" for arg in args: if isinstance(arg, kind): return arg raise ValueError("This should be unreachable.") +def _all_of_type(args, kind): + """Return all objects of type 'kind'""" + return [arg for arg in args if isinstance(arg, kind)] + + class _UFuncSignature: """Core dimensions signature for a given function. @@ -202,7 +212,10 @@ def _get_coords_list(args) -> List["Coordinates"]: def build_output_coords( - args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset() + args: list, + signature: _UFuncSignature, + exclude_dims: AbstractSet = frozenset(), + combine_attrs: str = "override", ) -> "List[Dict[Any, Variable]]": """Build output coordinates for an operation. @@ -230,7 +243,7 @@ def build_output_coords( else: # TODO: save these merged indexes, instead of re-computing them later merged_vars, unused_indexes = merge_coordinates_without_align( - coords_list, exclude_dims=exclude_dims + coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] @@ -248,7 +261,12 @@ def build_output_coords( def apply_dataarray_vfunc( - func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False + func, + *args, + signature, + join="inner", + exclude_dims=frozenset(), + keep_attrs="override", ): """Apply a variable level function over DataArray, Variable and/or ndarray objects. @@ -260,12 +278,16 @@ def apply_dataarray_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - if keep_attrs: + objs = _all_of_type(args, DataArray) + + if keep_attrs == "drop": + name = result_name(args) + else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - else: - name = result_name(args) - result_coords = build_output_coords(args, signature, exclude_dims) + result_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) data_vars = [getattr(a, "variable", a) for a in args] result_var = func(*data_vars) @@ -279,13 +301,12 @@ def apply_dataarray_vfunc( (coords,) = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) - if keep_attrs: - if isinstance(out, tuple): - for da in out: - # This is adding attrs in place - da._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for da in out: + da.attrs = attrs + else: + out.attrs = attrs return out @@ -307,12 +328,12 @@ def assert_and_return_exact_match(all_keys): if keys != first_keys: raise ValueError( "exact match required for all data variable names, " - "but %r != %r" % (keys, first_keys) + f"but {keys!r} != {first_keys!r}" ) return first_keys -_JOINERS = { +_JOINERS: Dict[str, Callable] = { "inner": ordered_set_intersection, "outer": ordered_set_union, "left": operator.itemgetter(0), @@ -400,7 +421,7 @@ def apply_dataset_vfunc( dataset_join="exact", fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), - keep_attrs=False, + keep_attrs="override", ): """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -414,15 +435,16 @@ def apply_dataset_vfunc( "dataset_fill_value argument." ) - if keep_attrs: - first_obj = _first_of_type(args, Dataset) + objs = _all_of_type(args, Dataset) if len(args) > 1: args = deep_align( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords(args, signature, exclude_dims) + list_of_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) args = [getattr(arg, "data_vars", arg) for arg in args] result_vars = apply_dict_of_variables_vfunc( @@ -435,13 +457,13 @@ def apply_dataset_vfunc( (coord_vars,) = list_of_coords out = _fast_dataset(result_vars, coord_vars) - if keep_attrs: - if isinstance(out, tuple): - for ds in out: - # This is adding attrs in place - ds._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for ds in out: + ds.attrs = attrs + else: + out.attrs = attrs + return out @@ -516,7 +538,7 @@ def unified_dim_sizes( if len(set(var.dims)) < len(var.dims): raise ValueError( "broadcasting cannot handle duplicate " - "dimensions on a variable: %r" % list(var.dims) + f"dimensions on a variable: {list(var.dims)}" ) for dim, size in zip(var.dims, var.shape): if dim not in exclude_dims: @@ -526,7 +548,7 @@ def unified_dim_sizes( raise ValueError( "operands cannot be broadcast together " "with mismatched lengths for dimension " - "%r: %s vs %s" % (dim, dim_sizes[dim], size) + f"{dim}: {dim_sizes[dim]} vs {size}" ) return dim_sizes @@ -563,8 +585,8 @@ def broadcast_compat_data( if unexpected_dims: raise ValueError( "operand to apply_ufunc encountered unexpected " - "dimensions %r on an input variable: these are core " - "dimensions on other input or output variables" % unexpected_dims + f"dimensions {unexpected_dims!r} on an input variable: these are core " + "dimensions on other input or output variables" ) # for consistency with numpy, keep broadcast dimensions to the left @@ -575,7 +597,7 @@ def broadcast_compat_data( data = duck_array_ops.transpose(data, order) if new_dims != reordered_dims: - key_parts = [] + key_parts: List[Optional[slice]] = [] for dim in new_dims: if dim in set_old_dims: key_parts.append(SLICE_NONE) @@ -609,14 +631,12 @@ def apply_variable_ufunc( dask="forbidden", output_dtypes=None, vectorize=False, - keep_attrs=False, + keep_attrs="override", dask_gufunc_kwargs=None, ): """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data - first_obj = _first_of_type(args, Variable) - dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims ) @@ -663,7 +683,7 @@ def apply_variable_ufunc( "apply_ufunc with dask='parallelized' consists of " "multiple chunks, but is also a core dimension. To " "fix, either rechunk into a single dask array chunk along " - f"this dimension, i.e., ``.chunk({dim}: -1)``, or " + f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or " "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " "but beware that this may significantly increase memory usage." ) @@ -700,9 +720,7 @@ def func(*arrays): # todo: covers for https://github.com/dask/dask/pull/6207 # remove when minimal dask version >= 2.17.0 - from dask import __version__ as dask_version - - if LooseVersion(dask_version) < LooseVersion("2.17.0"): + if dask_version < "2.17.0": if signature.num_outputs > 1: res = tuple(res) @@ -736,6 +754,12 @@ def func(*arrays): ) ) + objs = _all_of_type(args, Variable) + attrs = merge_attrs( + [obj.attrs for obj in objs], + combine_attrs=keep_attrs, + ) + output = [] for dims, data in zip(output_dims, result_data): data = as_compatible_data(data) @@ -758,8 +782,7 @@ def func(*arrays): ) ) - if keep_attrs: - var.attrs.update(first_obj.attrs) + var.attrs = attrs output.append(var) if signature.num_outputs == 1: @@ -801,7 +824,7 @@ def apply_ufunc( join: str = "exact", dataset_join: str = "exact", dataset_fill_value: object = _NO_FILL_VALUE, - keep_attrs: bool = False, + keep_attrs: Union[bool, str] = None, kwargs: Mapping = None, dask: str = "forbidden", output_dtypes: Sequence = None, @@ -885,11 +908,11 @@ def apply_ufunc( Value used in place of missing variables on Dataset inputs when the datasets do not share the exact same ``data_vars``. Required if ``dataset_join not in {'inner', 'exact'}``, otherwise ignored. - keep_attrs: bool, optional + keep_attrs : bool, optional Whether to copy attributes from the first argument to the output. - kwargs: dict, optional + kwargs : dict, optional Optional keyword arguments passed directly on to call ``func``. - dask: {"forbidden", "allowed", "parallelized"}, default: "forbidden" + dask : {"forbidden", "allowed", "parallelized"}, default: "forbidden" How to handle applying to objects containing lazy data in the form of dask arrays: @@ -923,9 +946,16 @@ def apply_ufunc( Single value or tuple of Dataset, DataArray, Variable, dask.array.Array or numpy.ndarray, the first type on that list to appear on an input. + Notes + ----- + This function is designed for the more common case where ``func`` can work on numpy + arrays. If ``func`` needs to manipulate a whole xarray object subset to each block + it is possible to use :py:func:`xarray.map_blocks`. + + Note that due to the overhead ``map_blocks`` is considerably slower than ``apply_ufunc``. + Examples -------- - Calculate the vector magnitude of two arguments: >>> def magnitude(a, b): @@ -959,51 +989,58 @@ def apply_ufunc( Other examples of how you could use ``apply_ufunc`` to write functions to (very nearly) replicate existing xarray functionality: - Compute the mean (``.mean``) over one dimension:: - - def mean(obj, dim): - # note: apply always moves core dimensions to the end - return apply_ufunc(np.mean, obj, - input_core_dims=[[dim]], - kwargs={'axis': -1}) + Compute the mean (``.mean``) over one dimension: - Inner product over a specific dimension (like ``xr.dot``):: - - def _inner(x, y): - result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) - return result[..., 0, 0] + >>> def mean(obj, dim): + ... # note: apply always moves core dimensions to the end + ... return apply_ufunc( + ... np.mean, obj, input_core_dims=[[dim]], kwargs={"axis": -1} + ... ) + ... - def inner_product(a, b, dim): - return apply_ufunc(_inner, a, b, input_core_dims=[[dim], [dim]]) + Inner product over a specific dimension (like ``xr.dot``): - Stack objects along a new dimension (like ``xr.concat``):: + >>> def _inner(x, y): + ... result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) + ... return result[..., 0, 0] + ... + >>> def inner_product(a, b, dim): + ... return apply_ufunc(_inner, a, b, input_core_dims=[[dim], [dim]]) + ... - def stack(objects, dim, new_coord): - # note: this version does not stack coordinates - func = lambda *x: np.stack(x, axis=-1) - result = apply_ufunc(func, *objects, - output_core_dims=[[dim]], - join='outer', - dataset_fill_value=np.nan) - result[dim] = new_coord - return result + Stack objects along a new dimension (like ``xr.concat``): + + >>> def stack(objects, dim, new_coord): + ... # note: this version does not stack coordinates + ... func = lambda *x: np.stack(x, axis=-1) + ... result = apply_ufunc( + ... func, + ... *objects, + ... output_core_dims=[[dim]], + ... join="outer", + ... dataset_fill_value=np.nan + ... ) + ... result[dim] = new_coord + ... return result + ... If your function is not vectorized but can be applied only to core dimensions, you can use ``vectorize=True`` to turn into a vectorized function. This wraps :py:func:`numpy.vectorize`, so the operation isn't terribly fast. Here we'll use it to calculate the distance between empirical samples from two probability distributions, using a scipy - function that needs to be applied to vectors:: - - import scipy.stats - - def earth_mover_distance(first_samples, - second_samples, - dim='ensemble'): - return apply_ufunc(scipy.stats.wasserstein_distance, - first_samples, second_samples, - input_core_dims=[[dim], [dim]], - vectorize=True) + function that needs to be applied to vectors: + + >>> import scipy.stats + >>> def earth_mover_distance(first_samples, second_samples, dim="ensemble"): + ... return apply_ufunc( + ... scipy.stats.wasserstein_distance, + ... first_samples, + ... second_samples, + ... input_core_dims=[[dim], [dim]], + ... vectorize=True, + ... ) + ... Most of NumPy's builtin functions already broadcast their inputs appropriately for use in `apply`. You may find helper functions such as @@ -1011,11 +1048,13 @@ def earth_mover_distance(first_samples, works well with numba's vectorize and guvectorize. Further explanation with examples are provided in the xarray documentation [3]_. - See also + See Also -------- numpy.broadcast_arrays numba.vectorize numba.guvectorize + dask.array.apply_gufunc + xarray.map_blocks References ---------- @@ -1082,6 +1121,12 @@ def earth_mover_distance(first_samples, if kwargs: func = functools.partial(func, **kwargs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + if isinstance(keep_attrs, bool): + keep_attrs = "override" if keep_attrs else "drop" + variables_vfunc = functools.partial( apply_variable_ufunc, func, @@ -1162,10 +1207,10 @@ def cov(da_a, da_b, dim=None, ddof=1): ------- covariance : DataArray - See also + See Also -------- pandas.Series.cov : corresponding pandas function - xarray.corr: respective function to calculate correlation + xarray.corr : respective function to calculate correlation Examples -------- @@ -1240,7 +1285,7 @@ def corr(da_a, da_b, dim=None): ------- correlation: DataArray - See also + See Also -------- pandas.Series.corr : corresponding pandas function xarray.cov : underlying covariance function @@ -1310,12 +1355,23 @@ def _cov_corr(da_a, da_b, dim=None, ddof=0, method=None): # 2. Ignore the nans valid_values = da_a.notnull() & da_b.notnull() + valid_count = valid_values.sum(dim) - ddof - if not valid_values.all(): - da_a = da_a.where(valid_values) - da_b = da_b.where(valid_values) + def _get_valid_values(da, other): + """ + Function to lazily mask da_a and da_b + following a similar approach to + https://github.com/pydata/xarray/pull/4559 + """ + missing_vals = np.logical_or(da.isnull(), other.isnull()) + if missing_vals.any(): + da = da.where(~missing_vals) + return da + else: + return da - valid_count = valid_values.sum(dim) - ddof + da_a = da_a.map_blocks(_get_valid_values, args=[da_b]) + da_b = da_b.map_blocks(_get_valid_values, args=[da_a]) # 3. Detrend along the given dim demeaned_da_a = da_a - da_a.mean(dim=dim) @@ -1346,7 +1402,7 @@ def dot(*arrays, dims=None, **kwargs): Parameters ---------- - arrays : DataArray or Variable + *arrays : DataArray or Variable Arrays to compute. dims : ..., str or tuple of str, optional Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. @@ -1361,9 +1417,6 @@ def dot(*arrays, dims=None, **kwargs): Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) >>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2), dims=["a", "b", "c"]) >>> da_c = xr.DataArray(np.arange(2 * 3).reshape(2, 3), dims=["c", "d"]) @@ -1496,7 +1549,6 @@ def where(cond, x, y): All dimension coordinates on `x` and `y` must be aligned with each other and with `cond`. - Parameters ---------- cond : scalar, array, Variable, DataArray or Dataset @@ -1514,8 +1566,6 @@ def where(cond, x, y): Examples -------- - >>> import xarray as xr - >>> import numpy as np >>> x = xr.DataArray( ... 0.1 * np.arange(10), ... dims=["lat"], @@ -1566,10 +1616,11 @@ def where(cond, x, y): [0, 0]]) Dimensions without coordinates: x, y - See also + See Also -------- numpy.where : corresponding numpy function - Dataset.where, DataArray.where : equivalent methods + Dataset.where, DataArray.where : + equivalent methods """ # alignment for three arguments is complicated, so don't support it yet return apply_ufunc( @@ -1595,7 +1646,7 @@ def polyval(coord, coeffs, degree_dim="degree"): degree_dim : str, default: "degree" Name of the polynomial degree dimension in `coeffs`. - See also + See Also -------- xarray.DataArray.polyfit numpy.polyval @@ -1676,3 +1727,61 @@ def _calc_idxminmax( res.attrs = indx.attrs return res + + +def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]: + """ + Given any number of Dataset and/or DataArray objects, returns + new objects with unified chunk size along all chunked dimensions. + + Returns + ------- + unified (DataArray or Dataset) – Tuple of objects with the same type as + *objects with consistent chunk sizes for all dask-array variables + + See Also + -------- + dask.array.core.unify_chunks + """ + from .dataarray import DataArray + + # Convert all objects to datasets + datasets = [ + obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy() + for obj in objects + ] + + # Get argumets to pass into dask.array.core.unify_chunks + unify_chunks_args = [] + sizes: dict[Hashable, int] = {} + for ds in datasets: + for v in ds._variables.values(): + if v.chunks is not None: + # Check that sizes match across different datasets + for dim, size in v.sizes.items(): + try: + if sizes[dim] != size: + raise ValueError( + f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}" + ) + except KeyError: + sizes[dim] = size + unify_chunks_args += [v._data, v._dims] + + # No dask arrays: Return inputs + if not unify_chunks_args: + return objects + + # Run dask.array.core.unify_chunks + from dask.array.core import unify_chunks + + _, dask_data = unify_chunks(*unify_chunks_args) + dask_data_iter = iter(dask_data) + out = [] + for obj, ds in zip(objects, datasets): + for k, v in ds._variables.items(): + if v.chunks is not None: + ds._variables[k] = v.copy(data=next(dask_data_iter)) + out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds) + + return tuple(out) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 5cda5aa903c..7a15685fd56 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -142,17 +142,23 @@ def concat( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ - default: "override" - String indicating how to combine attrs of the objects being merged: + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to the result. + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + Returns ------- concatenated : type of objs @@ -221,8 +227,7 @@ def concat( if compat not in _VALID_COMPAT: raise ValueError( - "compat=%r invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" - % compat + f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" ) if isinstance(first_obj, DataArray): @@ -232,7 +237,7 @@ def concat( else: raise TypeError( "can only concatenate xarray Dataset and DataArray " - "objects, got %s" % type(first_obj) + f"objects, got {type(first_obj)}" ) return f( objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs @@ -291,18 +296,16 @@ def process_subset_opt(opt, subset): if opt == "different": if compat == "override": raise ValueError( - "Cannot specify both %s='different' and compat='override'." - % subset + f"Cannot specify both {subset}='different' and compat='override'." ) # all nonindexes that are not the same in each dataset for k in getattr(datasets[0], subset): if k not in concat_over: equals[k] = None - variables = [] - for ds in datasets: - if k in ds.variables: - variables.append(ds.variables[k]) + variables = [ + ds.variables[k] for ds in datasets if k in ds.variables + ] if len(variables) == 1: # coords="different" doesn't make sense when only one object @@ -365,12 +368,12 @@ def process_subset_opt(opt, subset): if subset == "coords": raise ValueError( "some variables in coords are not coordinates on " - "the first dataset: %s" % (invalid_vars,) + f"the first dataset: {invalid_vars}" ) else: raise ValueError( "some variables in data_vars are not data variables " - "on the first dataset: %s" % (invalid_vars,) + f"on the first dataset: {invalid_vars}" ) concat_over.update(opt) @@ -423,6 +426,13 @@ def _dataset_concat( """ from .dataset import Dataset + datasets = list(datasets) + + if not all(isinstance(dataset, Dataset) for dataset in datasets): + raise TypeError( + "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" + ) + dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] @@ -437,7 +447,7 @@ def _dataset_concat( both_data_and_coords = coord_names & data_names if both_data_and_coords: raise ValueError( - "%r is a coordinate in some datasets but not others." % both_data_and_coords + f"{both_data_and_coords!r} is a coordinate in some datasets but not others." ) # we don't want the concat dimension in the result dataset yet dim_coords.pop(dim, None) @@ -505,8 +515,8 @@ def ensure_common_dims(vars): try: vars = ensure_common_dims([ds[k].variable for ds in datasets]) except KeyError: - raise ValueError("%r is not present in all datasets." % k) - combined = concat_vars(vars, dim, positions) + raise ValueError(f"{k!r} is not present in all datasets.") + combined = concat_vars(vars, dim, positions, combine_attrs=combine_attrs) assert isinstance(combined, Variable) result_vars[k] = combined elif k in result_vars: @@ -517,8 +527,7 @@ def ensure_common_dims(vars): absent_coord_names = coord_names - set(result.variables) if absent_coord_names: raise ValueError( - "Variables %r are coordinates in some datasets but not others." - % absent_coord_names + f"Variables {absent_coord_names!r} are coordinates in some datasets but not others." ) result = result.set_coords(coord_names) result.encoding = result_encoding @@ -543,8 +552,15 @@ def _dataarray_concat( join: str = "outer", combine_attrs: str = "override", ) -> "DataArray": + from .dataarray import DataArray + arrays = list(arrays) + if not all(isinstance(array, DataArray) for array in arrays): + raise TypeError( + "The elements in the input list need to be either all 'Dataset's or all 'DataArray's" + ) + if data_vars != "all": raise ValueError( "data_vars is not a valid argument when concatenating DataArray objects" @@ -570,7 +586,7 @@ def _dataarray_concat( positions, fill_value=fill_value, join=join, - combine_attrs="drop", + combine_attrs=combine_attrs, ) merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 37c462f79f4..767b76d0d12 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -13,10 +13,11 @@ cast, ) +import numpy as np import pandas as pd from . import formatting, indexing -from .indexes import Indexes +from .indexes import Index, Indexes from .merge import merge_coordinates_without_align, merge_coords from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable @@ -49,7 +50,11 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]: @property def indexes(self) -> Indexes: - return self._data.indexes # type: ignore + return self._data.indexes # type: ignore[attr-defined] + + @property + def xindexes(self) -> Indexes: + return self._data.xindexes # type: ignore[attr-defined] @property def variables(self): @@ -104,26 +109,70 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: raise ValueError("no valid index for a 0-dimensional object") elif len(ordered_dims) == 1: (dim,) = ordered_dims - return self._data.get_index(dim) # type: ignore + return self._data.get_index(dim) # type: ignore[attr-defined] else: - indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore - names = list(ordered_dims) - return pd.MultiIndex.from_product(indexes, names=names) + indexes = [ + self._data.get_index(k) for k in ordered_dims # type: ignore[attr-defined] + ] + + # compute the sizes of the repeat and tile for the cartesian product + # (taken from pandas.core.reshape.util) + index_lengths = np.fromiter( + (len(index) for index in indexes), dtype=np.intp + ) + cumprod_lengths = np.cumproduct(index_lengths) + + if cumprod_lengths[-1] == 0: + # if any factor is empty, the cartesian product is empty + repeat_counts = np.zeros_like(cumprod_lengths) + + else: + # sizes of the repeats + repeat_counts = cumprod_lengths[-1] / cumprod_lengths + # sizes of the tiles + tile_counts = np.roll(cumprod_lengths, 1) + tile_counts[0] = 1 + + # loop over the indexes + # for each MultiIndex or Index compute the cartesian product of the codes + + code_list = [] + level_list = [] + names = [] + + for i, index in enumerate(indexes): + if isinstance(index, pd.MultiIndex): + codes, levels = index.codes, index.levels + else: + code, level = pd.factorize(index) + codes = [code] + levels = [level] + + # compute the cartesian product + code_list += [ + np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i]) + for code in codes + ] + level_list += levels + names += index.names + + return pd.MultiIndex(level_list, code_list, names=names) def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) coords, indexes = merge_coords( - [self.variables, other_vars], priority_arg=1, indexes=self.indexes + [self.variables, other_vars], priority_arg=1, indexes=self.xindexes ) self._update_coords(coords, indexes) - def _merge_raw(self, other): + def _merge_raw(self, other, reflexive): """For use with binary arithmetic.""" if other is None: variables = dict(self.variables) - indexes = dict(self.indexes) + indexes = dict(self.xindexes) else: - variables, indexes = merge_coordinates_without_align([self, other]) + coord_list = [self, other] if not reflexive else [other, self] + variables, indexes = merge_coordinates_without_align(coord_list) return variables, indexes @contextmanager @@ -135,7 +184,9 @@ def _merge_inplace(self, other): # don't include indexes in prioritized, because we didn't align # first and we want indexes to be checked prioritized = { - k: (v, None) for k, v in self.variables.items() if k not in self.indexes + k: (v, None) + for k, v in self.variables.items() + if k not in self.xindexes } variables, indexes = merge_coordinates_without_align( [self, other], prioritized @@ -175,10 +226,9 @@ def merge(self, other: "Coordinates") -> "Dataset": coords, indexes = merge_coordinates_without_align([self, other]) coord_names = set(coords) - merged = Dataset._construct_direct( + return Dataset._construct_direct( variables=coords, coord_names=coord_names, indexes=indexes ) - return merged class DatasetCoordinates(Coordinates): @@ -220,7 +270,7 @@ def to_dataset(self) -> "Dataset": return self._data._copy_listed(names) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] ) -> None: from .dataset import calculate_dimensions @@ -240,7 +290,7 @@ def _update_coords( # TODO(shoyer): once ._indexes is always populated by a dict, modify # it to update inplace instead. - original_indexes = dict(self._data.indexes) + original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes @@ -251,7 +301,7 @@ def __delitem__(self, key: Hashable) -> None: raise KeyError(f"{key!r} is not a coordinate variable.") def _ipython_key_completions_(self): - """Provide method for the key-autocompletions in IPython. """ + """Provide method for the key-autocompletions in IPython.""" return [ key for key in self._data._ipython_key_completions_() @@ -283,7 +333,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] ) -> None: from .dataset import calculate_dimensions @@ -298,7 +348,7 @@ def _update_coords( # TODO(shoyer): once ._indexes is always populated by a dict, modify # it to update inplace instead. - original_indexes = dict(self._data.indexes) + original_indexes = dict(self._data.xindexes) original_indexes.update(indexes) self._data._indexes = original_indexes @@ -313,15 +363,15 @@ def to_dataset(self) -> "Dataset": return Dataset._construct_direct(coords, set(coords)) def __delitem__(self, key: Hashable) -> None: - if key in self: - del self._data._coords[key] - if self._data._indexes is not None and key in self._data._indexes: - del self._data._indexes[key] - else: + if key not in self: raise KeyError(f"{key!r} is not a coordinate variable.") + del self._data._coords[key] + if self._data._indexes is not None and key in self._data._indexes: + del self._data._indexes[key] + def _ipython_key_completions_(self): - """Provide method for the key-autocompletions in IPython. """ + """Provide method for the key-autocompletions in IPython.""" return self._data._ipython_key_completions_() @@ -335,14 +385,11 @@ def assert_coordinate_consistent( """ for k in obj.dims: # make sure there are no conflict in dimension coordinates - if k in coords and k in obj.coords: - if not coords[k].equals(obj[k].variable): - raise IndexError( - "dimension coordinate {!r} conflicts between " - "indexed and indexing objects:\n{}\nvs.\n{}".format( - k, obj[k], coords[k] - ) - ) + if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable): + raise IndexError( + f"dimension coordinate {k!r} conflicts between " + f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}" + ) def remap_label_indexers( diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index ce15e01fb12..c0b99d430d4 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,14 +1,12 @@ import warnings -from distutils.version import LooseVersion -from typing import Iterable import numpy as np +from .pycompat import dask_version + try: import dask.array as da - from dask import __version__ as dask_version except ImportError: - dask_version = "0.0.0" da = None @@ -58,38 +56,130 @@ def pad(array, pad_width, mode="constant", **kwargs): return padded -if LooseVersion(dask_version) > LooseVersion("2.9.0"): - nanmedian = da.nanmedian +if dask_version > "2.30.0": + ensure_minimum_chunksize = da.overlap.ensure_minimum_chunksize else: - def nanmedian(a, axis=None, keepdims=False): - """ - This works by automatically chunking the reduced axes to a single chunk - and then calling ``numpy.nanmedian`` function across the remaining dimensions + # copied from dask + def ensure_minimum_chunksize(size, chunks): + """Determine new chunks to ensure that every chunk >= size + + Parameters + ---------- + size : int + The maximum size of any chunk. + chunks : tuple + Chunks along one axis, e.g. ``(3, 3, 2)`` + + Examples + -------- + >>> ensure_minimum_chunksize(10, (20, 20, 1)) + (20, 11, 10) + >>> ensure_minimum_chunksize(3, (1, 1, 3)) + (5,) + + See Also + -------- + overlap """ - - if axis is None: - raise NotImplementedError( - "The da.nanmedian function only works along an axis. " - "The full algorithm is difficult to do in parallel" + if size <= min(chunks): + return chunks + + # add too-small chunks to chunks before them + output = [] + new = 0 + for c in chunks: + if c < size: + if new > size + (size - c): + output.append(new - (size - c)) + new = size + else: + new += c + if new >= size: + output.append(new) + new = 0 + if c >= size: + new += c + if new >= size: + output.append(new) + elif len(output) >= 1: + output[-1] += new + else: + raise ValueError( + f"The overlapping depth {size} is larger than your " + f"array {sum(chunks)}." ) - if not isinstance(axis, Iterable): - axis = (axis,) + return tuple(output) - axis = [ax + a.ndim if ax < 0 else ax for ax in axis] - a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) +if dask_version > "2021.03.0": + sliding_window_view = da.lib.stride_tricks.sliding_window_view +else: - result = da.map_blocks( - np.nanmedian, - a, - axis=axis, - keepdims=keepdims, - drop_axis=axis if not keepdims else None, - chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] - if keepdims - else None, + def sliding_window_view(x, window_shape, axis=None): + from dask.array.overlap import map_overlap + from numpy.core.numeric import normalize_axis_tuple + + from .npcompat import sliding_window_view as _np_sliding_window_view + + window_shape = ( + tuple(window_shape) if np.iterable(window_shape) else (window_shape,) ) - return result + window_shape_array = np.array(window_shape) + if np.any(window_shape_array <= 0): + raise ValueError("`window_shape` must contain positive values") + + if axis is None: + axis = tuple(range(x.ndim)) + if len(window_shape) != len(axis): + raise ValueError( + f"Since axis is `None`, must provide " + f"window_shape for all dimensions of `x`; " + f"got {len(window_shape)} window_shape elements " + f"and `x.ndim` is {x.ndim}." + ) + else: + axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) + if len(window_shape) != len(axis): + raise ValueError( + f"Must provide matching length window_shape and " + f"axis; got {len(window_shape)} window_shape " + f"elements and {len(axis)} axes elements." + ) + + depths = [0] * x.ndim + for ax, window in zip(axis, window_shape): + depths[ax] += window - 1 + + # Ensure that each chunk is big enough to leave at least a size-1 chunk + # after windowing (this is only really necessary for the last chunk). + safe_chunks = tuple( + ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks) + ) + x = x.rechunk(safe_chunks) + + # result.shape = x_shape_trimmed + window_shape, + # where x_shape_trimmed is x.shape with every entry + # reduced by one less than the corresponding window size. + # trim chunks to match x_shape_trimmed + newchunks = tuple( + c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks) + ) + tuple((window,) for window in window_shape) + + kwargs = dict( + depth=tuple((0, d) for d in depths), # Overlap on +ve side only + boundary="none", + meta=x._meta, + new_axis=range(x.ndim, x.ndim + len(axis)), + chunks=newchunks, + trim=False, + window_shape=window_shape, + axis=axis, + ) + # map_overlap's signature changed in https://github.com/dask/dask/pull/6165 + if dask_version > "2.18.0": + return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs) + else: + return map_overlap(x, _np_sliding_window_view, **kwargs) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 15641506e4e..5eeb22767c8 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,5 +1,3 @@ -import numpy as np - from . import dtypes, nputils @@ -26,92 +24,6 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): return result -def rolling_window(a, axis, window, center, fill_value): - """Dask's equivalence to np.utils.rolling_window""" - import dask.array as da - - if not hasattr(axis, "__len__"): - axis = [axis] - window = [window] - center = [center] - - orig_shape = a.shape - depth = {d: 0 for d in range(a.ndim)} - offset = [0] * a.ndim - drop_size = [0] * a.ndim - pad_size = [0] * a.ndim - for ax, win, cent in zip(axis, window, center): - if ax < 0: - ax = a.ndim + ax - depth[ax] = int(win / 2) - # For evenly sized window, we need to crop the first point of each block. - offset[ax] = 1 if win % 2 == 0 else 0 - - if depth[ax] > min(a.chunks[ax]): - raise ValueError( - "For window size %d, every chunk should be larger than %d, " - "but the smallest chunk size is %d. Rechunk your array\n" - "with a larger chunk size or a chunk size that\n" - "more evenly divides the shape of your array." - % (win, depth[ax], min(a.chunks[ax])) - ) - - # Although da.overlap pads values to boundaries of the array, - # the size of the generated array is smaller than what we want - # if center == False. - if cent: - start = int(win / 2) # 10 -> 5, 9 -> 4 - end = win - 1 - start - else: - start, end = win - 1, 0 - pad_size[ax] = max(start, end) + offset[ax] - depth[ax] - drop_size[ax] = 0 - # pad_size becomes more than 0 when the overlapped array is smaller than - # needed. In this case, we need to enlarge the original array by padding - # before overlapping. - if pad_size[ax] > 0: - if pad_size[ax] < depth[ax]: - # overlapping requires each chunk larger than depth. If pad_size is - # smaller than the depth, we enlarge this and truncate it later. - drop_size[ax] = depth[ax] - pad_size[ax] - pad_size[ax] = depth[ax] - - # TODO maybe following two lines can be summarized. - a = da.pad( - a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value - ) - boundary = {d: fill_value for d in range(a.ndim)} - - # create overlap arrays - ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - - def func(x, window, axis): - x = np.asarray(x) - index = [slice(None)] * x.ndim - for ax, win in zip(axis, window): - x = nputils._rolling_window(x, win, ax) - index[ax] = slice(offset[ax], None) - return x[tuple(index)] - - chunks = list(a.chunks) + window - new_axis = [a.ndim + i for i in range(len(axis))] - out = da.map_blocks( - func, - ag, - dtype=a.dtype, - new_axis=new_axis, - chunks=chunks, - window=window, - axis=axis, - ) - - # crop boundary. - index = [slice(None)] * a.ndim - for ax in axis: - index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax]) - return out[tuple(index)] - - def least_squares(lhs, rhs, rcond=None, skipna=False): import dask.array as da @@ -139,3 +51,30 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): # See issue dask/dask#6516 coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs) return coeffs, residuals + + +def push(array, n, axis): + """ + Dask-aware bottleneck.push + """ + from bottleneck import push + + if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]: + raise NotImplementedError( + "Cannot fill along a chunked axis when limit is not None." + "Either rechunk to a single chunk along this axis or call .compute() or .load() first." + ) + if all(c == 1 for c in array.chunks[axis]): + array = array.rechunk({axis: 2}) + pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta) + if len(array.chunks[axis]) > 1: + pushed = pushed.map_overlap( + push, + axis=axis, + n=n, + depth={axis: (1, 0)}, + boundary="none", + dtype=array.dtype, + meta=array._meta, + ) + return pushed diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0155cdc4e19..900af885319 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,7 +1,5 @@ import datetime -import functools import warnings -from numbers import Number from typing import ( TYPE_CHECKING, Any, @@ -43,7 +41,9 @@ align, reindex_like_indexers, ) +from .arithmetic import DataArrayArithmetic from .common import AbstractArray, DataWithCoords +from .computation import unify_chunks from .coordinates import ( DataArrayCoordinates, assert_coordinate_consistent, @@ -51,7 +51,7 @@ ) from .dataset import Dataset, split_indexes from .formatting import format_item -from .indexes import Indexes, default_indexes, propagate_indexes +from .indexes import Index, Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs @@ -70,9 +70,9 @@ assert_unique_multiindex_level_names, ) +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) if TYPE_CHECKING: - T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) - try: from dask.delayed import Delayed except ImportError: @@ -98,39 +98,34 @@ def _infer_coords_and_dims( and len(coords) != len(shape) ): raise ValueError( - "coords is not dict-like, but it has %s items, " - "which does not match the %s dimensions of the " - "data" % (len(coords), len(shape)) + f"coords is not dict-like, but it has {len(coords)} items, " + f"which does not match the {len(shape)} dimensions of the " + "data" ) if isinstance(dims, str): dims = (dims,) if dims is None: - dims = ["dim_%s" % n for n in range(len(shape))] + dims = [f"dim_{n}" for n in range(len(shape))] if coords is not None and len(coords) == len(shape): # try to infer dimensions from coords if utils.is_dict_like(coords): - # deprecated in GH993, removed in GH1539 - raise ValueError( - "inferring DataArray dimensions from " - "dictionary like ``coords`` is no longer " - "supported. Use an explicit list of " - "``dims`` instead." - ) - for n, (dim, coord) in enumerate(zip(dims, coords)): - coord = as_variable(coord, name=dims[n]).to_index_variable() - dims[n] = coord.name + dims = list(coords.keys()) + else: + for n, (dim, coord) in enumerate(zip(dims, coords)): + coord = as_variable(coord, name=dims[n]).to_index_variable() + dims[n] = coord.name dims = tuple(dims) elif len(dims) != len(shape): raise ValueError( "different number of dimensions on data " - "and dims: %s vs %s" % (len(shape), len(dims)) + f"and dims: {len(shape)} vs {len(dims)}" ) else: for d in dims: if not isinstance(d, str): - raise TypeError("dimension %s is not a string" % d) + raise TypeError(f"dimension {d} is not a string") new_coords: Dict[Any, Variable] = {} @@ -147,24 +142,24 @@ def _infer_coords_and_dims( for k, v in new_coords.items(): if any(d not in dims for d in v.dims): raise ValueError( - "coordinate %s has dimensions %s, but these " + f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - "dimensions %s" % (k, v.dims, dims) + f"dimensions {dims}" ) for d, s in zip(v.dims, v.shape): if s != sizes[d]: raise ValueError( - "conflicting sizes for dimension %r: " - "length %s on the data but length %s on " - "coordinate %r" % (d, sizes[d], s, k) + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" ) if k in sizes and v.shape != (sizes[k],): raise ValueError( - "coordinate %r is a DataArray dimension, but " - "it has shape %r rather than expected shape %r " - "matching the dimension size" % (k, v.shape, (sizes[k],)) + f"coordinate {k!r} is a DataArray dimension, but " + f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " + "matching the dimension size" ) assert_unique_multiindex_level_names(new_coords) @@ -218,7 +213,7 @@ def __setitem__(self, key, value) -> None: _THIS_ARRAY = ReprObject("") -class DataArray(AbstractArray, DataWithCoords): +class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic): """N-dimensional array with labeled coordinates and dimensions. DataArray provides a wrapper around numpy ndarrays that uses @@ -275,7 +270,8 @@ class DataArray(AbstractArray, DataWithCoords): Name(s) of the data dimension(s). Must be either a hashable (only for 1D data) or a sequence of hashables with length equal to the number of dimensions. If this argument is omitted, - dimension names default to ``['dim_0', ... 'dim_n']``. + dimension names are taken from ``coords`` (if possible) and + otherwise default to ``['dim_0', ... 'dim_n']``. name : str or None, optional Name of this array. attrs : dict_like or None, optional @@ -288,7 +284,6 @@ class DataArray(AbstractArray, DataWithCoords): >>> np.random.seed(0) >>> temperature = 15 + 8 * np.random.randn(2, 2, 3) - >>> precipitation = 10 * np.random.rand(2, 2, 3) >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] >>> lat = [[42.25, 42.21], [42.63, 42.59]] >>> time = pd.date_range("2014-09-06", periods=3) @@ -345,7 +340,7 @@ class DataArray(AbstractArray, DataWithCoords): _cache: Dict[str, Any] _coords: Dict[Any, Variable] _close: Optional[Callable[[], None]] - _indexes: Optional[Dict[Hashable, pd.Index]] + _indexes: Optional[Dict[Hashable, Index]] _name: Optional[Hashable] _variable: Variable @@ -425,12 +420,12 @@ def __init__( self._close = None def _replace( - self, + self: T_DataArray, variable: Variable = None, coords=None, name: Union[Hashable, None, Default] = _default, indexes=None, - ) -> "DataArray": + ) -> T_DataArray: if variable is None: variable = self.variable if coords is None: @@ -472,13 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": return self coords = self._coords.copy() for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx) + coords[name] = IndexVariable(name, idx.to_pandas_index()) obj = self._replace(coords=coords) # switch from dimension to level names, if necessary dim_names: Dict[Any, str] = {} for dim, idx in indexes.items(): - if not isinstance(idx, pd.MultiIndex) and idx.name != dim: + pd_idx = idx.to_pandas_index() + if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim: dim_names[dim] = idx.name if dim_names: obj = obj.rename(dim_names) @@ -496,7 +492,7 @@ def _from_temp_dataset( return self._replace(variable, coords, name, indexes=indexes) def _to_dataset_split(self, dim: Hashable) -> Dataset: - """ splits dataarray along dimension 'dim' """ + """splits dataarray along dimension 'dim'""" def subset(dim, label): array = self.loc[{dim: label}] @@ -537,8 +533,7 @@ def _to_dataset_whole( indexes = self._indexes coord_names = set(self._coords) - dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes) - return dataset + return Dataset._construct_direct(variables, coord_names, indexes=indexes) def to_dataset( self, @@ -621,7 +616,16 @@ def __len__(self) -> int: @property def data(self) -> Any: - """The array's data as a dask or numpy array""" + """ + The DataArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + See Also + -------- + DataArray.to_numpy + DataArray.as_numpy + DataArray.values + """ return self.variable.data @data.setter @@ -630,13 +634,46 @@ def data(self, value: Any) -> None: @property def values(self) -> np.ndarray: - """The array's data as a numpy.ndarray""" + """ + The array's data as a numpy.ndarray. + + If the array's data is not a numpy.ndarray this will attempt to convert + it naively using np.array(), which will raise an error if the array + type does not support coercion like this (e.g. cupy). + """ return self.variable.values @values.setter def values(self, value: Any) -> None: self.variable.values = value + def to_numpy(self) -> np.ndarray: + """ + Coerces wrapped data to numpy and returns a numpy.ndarray. + + See also + -------- + DataArray.as_numpy : Same but returns the surrounding DataArray instead. + Dataset.as_numpy + DataArray.values + DataArray.data + """ + return self.variable.to_numpy() + + def as_numpy(self: T_DataArray) -> T_DataArray: + """ + Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. + + See also + -------- + DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object. + Dataset.as_numpy : Converts all variables in a Dataset. + DataArray.values + DataArray.data + """ + coords = {k: v.as_numpy() for k, v in self._coords.items()} + return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + @property def _in_memory(self) -> bool: return self.variable._in_memory @@ -667,9 +704,8 @@ def dims(self, value): def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: if utils.is_dict_like(key): return key - else: - key = indexing.expanded_indexer(key, self.ndim) - return dict(zip(self.dims, key)) + key = indexing.expanded_indexer(key, self.ndim) + return dict(zip(self.dims, key)) @property def _level_coords(self) -> Dict[Hashable, Hashable]: @@ -758,7 +794,7 @@ def attrs(self) -> Dict[Hashable, Any]: @attrs.setter def attrs(self, value: Mapping[Hashable, Any]) -> None: # Disable type checking to work around mypy bug - see mypy#4167 - self.variable.attrs = value # type: ignore + self.variable.attrs = value # type: ignore[assignment] @property def encoding(self) -> Dict[Hashable, Any]: @@ -772,7 +808,21 @@ def encoding(self, value: Mapping[Hashable, Any]) -> None: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing""" + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + DataArray.xindexes + + """ + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + + @property + def xindexes(self) -> Indexes: + """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._coords, self.dims) return Indexes(self._indexes) @@ -807,13 +857,12 @@ def reset_coords( dataset = self.coords.to_dataset().reset_coords(names, drop) if drop: return self._replace(coords=dataset._variables) - else: - if self.name is None: - raise ValueError( - "cannot reset_coords with drop=False on an unnamed DataArrray" - ) - dataset[self.name] = self.variable - return dataset + if self.name is None: + raise ValueError( + "cannot reset_coords with drop=False on an unnamed DataArrray" + ) + dataset[self.name] = self.variable + return dataset def __dask_tokenize__(self): from dask.base import normalize_token @@ -839,15 +888,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): func, args = self._to_temp_dataset().__dask_postcompute__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args @staticmethod - def _dask_finalize(results, func, args, name): - ds = func(results, *args) + def _dask_finalize(results, name, func, *args, **kwargs): + ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables return DataArray(variable, coords, name=name, fastpath=True) @@ -917,7 +966,7 @@ def persist(self, **kwargs) -> "DataArray": ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep: bool = True, data: Any = None) -> "DataArray": + def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -945,7 +994,6 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": Examples -------- - Shallow versus deep copy >>> array = xr.DataArray([1, 2, 3], dims="x", coords={"x": ["a", "b", "c"]}) @@ -1004,7 +1052,7 @@ def __deepcopy__(self, memo=None) -> "DataArray": # mutable objects should not be hashable # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore + __hash__ = None # type: ignore[assignment] @property def chunks(self) -> Optional[Tuple[Tuple[int, ...], ...]]: @@ -1016,10 +1064,10 @@ def chunks(self) -> Optional[Tuple[Tuple[int, ...], ...]]: def chunk( self, chunks: Union[ - Number, - Tuple[Number, ...], - Tuple[Tuple[Number, ...], ...], - Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]], + int, + Tuple[int, ...], + Tuple[Tuple[int, ...], ...], + Mapping[Hashable, Union[None, int, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, @@ -1086,7 +1134,7 @@ def isel( What to do if dimensions that should be selected from are not present in the DataArray: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **indexers_kwargs : {dim: indexer, ...}, optional The keyword arguments form of ``indexers``. @@ -1095,6 +1143,26 @@ def isel( -------- Dataset.isel DataArray.sel + + Examples + -------- + >>> da = xr.DataArray(np.arange(25).reshape(5, 5), dims=("x", "y")) + >>> da + + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + + >>> tgt_x = xr.DataArray(np.arange(0, 5), dims="points") + >>> tgt_y = xr.DataArray(np.arange(0, 5), dims="points") + >>> da = da.isel(x=tgt_x, y=tgt_y) + >>> da + + array([ 0, 6, 12, 18, 24]) + Dimensions without coordinates: points """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") @@ -1203,6 +1271,34 @@ def sel( Dataset.sel DataArray.isel + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(25).reshape(5, 5), + ... coords={"x": np.arange(5), "y": np.arange(5)}, + ... dims=("x", "y"), + ... ) + >>> da + + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Coordinates: + * x (x) int64 0 1 2 3 4 + * y (y) int64 0 1 2 3 4 + + >>> tgt_x = xr.DataArray(np.linspace(0, 4, num=5), dims="points") + >>> tgt_y = xr.DataArray(np.linspace(0, 4, num=5), dims="points") + >>> da = da.sel(x=tgt_x, y=tgt_y, method="nearest") + >>> da + + array([ 0, 6, 12, 18, 24]) + Coordinates: + x (points) int64 0 1 2 3 4 + y (points) int64 0 1 2 3 4 + Dimensions without coordinates: points """ ds = self._to_temp_dataset().sel( indexers=indexers, @@ -1294,7 +1390,6 @@ def broadcast_like( Examples -------- - >>> arr1 = xr.DataArray( ... np.random.randn(2, 3), ... dims=("x", "y"), @@ -1452,6 +1547,26 @@ def reindex( Another dataset array, with this array's data but replaced coordinates. + Examples + -------- + Reverse latitude: + + >>> da = xr.DataArray( + ... np.arange(4), + ... coords=[np.array([90, 89, 88, 87])], + ... dims="lat", + ... ) + >>> da + + array([0, 1, 2, 3]) + Coordinates: + * lat (lat) int64 90 89 88 87 + >>> da.reindex(lat=da.lat[::-1]) + + array([3, 2, 1, 0]) + Coordinates: + * lat (lat) int64 87 88 89 90 + See Also -------- DataArray.reindex_like @@ -1709,8 +1824,7 @@ def swap_dims( dims_dict : dict-like Dictionary whose keys are current dimension names and whose values are new names. - - **dim_kwargs : {dim: , ...}, optional + **dims_kwargs : {existing_dim: new_dim, ...}, optional The keyword arguments form of ``dims_dict``. One of dims_dict or dims_kwargs must be provided. @@ -1721,7 +1835,6 @@ def swap_dims( Examples -------- - >>> arr = xr.DataArray( ... data=[0, 1], ... dims="x", @@ -1751,7 +1864,6 @@ def swap_dims( See Also -------- - DataArray.rename Dataset.swap_dims """ @@ -1769,7 +1881,6 @@ def expand_dims( the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. - If dim is already a scalar coordinate, it will be promoted to a 1D coordinate consisting of a single value. @@ -1817,7 +1928,7 @@ def set_index( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, **indexes_kwargs: Union[Hashable, Sequence[Hashable]], - ) -> Optional["DataArray"]: + ) -> "DataArray": """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -1873,7 +1984,7 @@ def reset_index( self, dims_or_levels: Union[Hashable, Sequence[Hashable]], drop: bool = False, - ) -> Optional["DataArray"]: + ) -> "DataArray": """Reset the specified index(es) or multi-index level(s). Parameters @@ -1929,7 +2040,7 @@ def reorder_levels( coord = self._coords[dim] index = coord.to_index() if not isinstance(index, pd.MultiIndex): - raise ValueError("coordinate %r has no MultiIndex" % dim) + raise ValueError(f"coordinate {dim!r} has no MultiIndex") replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order)) coords = self._coords.copy() coords.update(replace_coords) @@ -1965,7 +2076,6 @@ def stack( Examples -------- - >>> arr = xr.DataArray( ... np.arange(6).reshape(2, 3), ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], @@ -2026,7 +2136,6 @@ def unstack( Examples -------- - >>> arr = xr.DataArray( ... np.arange(6).reshape(2, 3), ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], @@ -2071,9 +2180,6 @@ def to_unstacked_dataset(self, dim, level=0): level : int or str The MultiIndex level to expand to a dataset along. Can either be the integer index of the level or its name. - label : int, default: 0 - Label of the level to expand dataset along. Overrides the label - argument if given. Returns ------- @@ -2081,7 +2187,6 @@ def to_unstacked_dataset(self, dim, level=0): Examples -------- - >>> import xarray as xr >>> arr = xr.DataArray( ... np.arange(6).reshape(2, 3), ... coords=[("x", ["a", "b"]), ("y", [0, 1, 2])], @@ -2112,7 +2217,9 @@ def to_unstacked_dataset(self, dim, level=0): Dataset.to_stacked_array """ - idx = self.indexes[dim] + # TODO: benbovy - flexible indexes: update when MultIndex has its own + # class inheriting from xarray.Index + idx = self.xindexes[dim].to_pandas_index() if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") @@ -2147,7 +2254,7 @@ def transpose( What to do if dimensions that should be selected from are not present in the DataArray: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions Returns @@ -2191,7 +2298,7 @@ def drop_vars( ---------- names : hashable or iterable of hashable Name(s) of variables to drop. - errors: {"raise", "ignore"}, optional + errors : {"raise", "ignore"}, optional If 'raise' (default), raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the DataArray are dropped and no error is raised. @@ -2357,7 +2464,6 @@ def interpolate_na( provided. - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. - use_coordinate : bool or str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if @@ -2369,7 +2475,7 @@ def interpolate_na( or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -2396,7 +2502,7 @@ def interpolate_na( If True, the dataarray's attributes (`attrs`) will be copied from the original object to the new one. If False, the new object will be returned without attributes. - kwargs : dict, optional + **kwargs : dict, optional parameters passed verbatim to the underlying interpolation function Returns @@ -2404,7 +2510,7 @@ def interpolate_na( interpolated: DataArray Filled in DataArray. - See also + See Also -------- numpy.interp scipy.interpolate @@ -2459,7 +2565,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "DataArray": The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -2483,7 +2590,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "DataArray": The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -2577,8 +2685,8 @@ def to_pandas(self) -> Union["DataArray", pd.Series, pd.DataFrame]: constructor = constructors[self.ndim] except KeyError: raise ValueError( - "cannot convert arrays with %s dimensions into " - "pandas objects" % self.ndim + f"cannot convert arrays with {self.ndim} dimensions into " + "pandas objects" ) indexes = [self.get_index(dim) for dim in self.dims] return constructor(self.values, *indexes) @@ -2664,7 +2772,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: result : MaskedArray Masked where invalid values (nan or inf) occur. """ - values = self.values # only compute lazy arrays once + values = self.to_numpy() # only compute lazy arrays once isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) @@ -2716,7 +2824,7 @@ def to_dict(self, data: bool = True) -> dict: Whether to include the actual data in the dictionary. When set to False, returns just the schema. - See also + See Also -------- DataArray.from_dict """ @@ -2757,7 +2865,7 @@ def from_dict(cls, d: dict) -> "DataArray": ------- obj : xarray.DataArray - See also + See Also -------- DataArray.to_dict Dataset.from_dict @@ -2794,7 +2902,7 @@ def from_series(cls, series: pd.Series, sparse: bool = False) -> "DataArray": If sparse=True, creates a sparse array instead of a dense NumPy array. Requires the pydata/sparse package. - See also + See Also -------- xarray.Dataset.from_dataframe """ @@ -2911,82 +3019,67 @@ def __rmatmul__(self, other): # compatible with matmul return computation.dot(other, self) - @staticmethod - def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]: - @functools.wraps(f) - def func(self, *args, **kwargs): - keep_attrs = kwargs.pop("keep_attrs", None) - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - warnings.filterwarnings( - "ignore", r"Mean of empty slice", category=RuntimeWarning - ) - with np.errstate(all="ignore"): - da = self.__array_wrap__(f(self.variable.data, *args, **kwargs)) - if keep_attrs: - da.attrs = self.attrs - return da - - return func + def _unary_op(self, f: Callable, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + with np.errstate(all="ignore"): + da = self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + if keep_attrs: + da.attrs = self.attrs + return da - @staticmethod def _binary_op( - f: Callable[..., Any], + self, + other, + f: Callable, reflexive: bool = False, - join: str = None, # see xarray.align - **ignored_kwargs, - ) -> Callable[..., "DataArray"]: - @functools.wraps(f) - def func(self, other): - if isinstance(other, (Dataset, groupby.GroupBy)): - return NotImplemented - if isinstance(other, DataArray): - align_type = OPTIONS["arithmetic_join"] if join is None else join - self, other = align(self, other, join=align_type, copy=False) - other_variable = getattr(other, "variable", other) - other_coords = getattr(other, "coords", None) - - variable = ( - f(self.variable, other_variable) - if not reflexive - else f(other_variable, self.variable) - ) - coords, indexes = self.coords._merge_raw(other_coords) - name = self._result_name(other) - - return self._replace(variable, coords, name, indexes=indexes) - - return func + ): + if isinstance(other, (Dataset, groupby.GroupBy)): + return NotImplemented + if isinstance(other, DataArray): + align_type = OPTIONS["arithmetic_join"] + self, other = align(self, other, join=align_type, copy=False) + other_variable = getattr(other, "variable", other) + other_coords = getattr(other, "coords", None) + + variable = ( + f(self.variable, other_variable) + if not reflexive + else f(other_variable, self.variable) + ) + coords, indexes = self.coords._merge_raw(other_coords, reflexive) + name = self._result_name(other) - @staticmethod - def _inplace_binary_op(f: Callable) -> Callable[..., "DataArray"]: - @functools.wraps(f) - def func(self, other): - if isinstance(other, groupby.GroupBy): - raise TypeError( - "in-place operations between a DataArray and " - "a grouped object are not permitted" - ) - # n.b. we can't align other to self (with other.reindex_like(self)) - # because `other` may be converted into floats, which would cause - # in-place arithmetic to fail unpredictably. Instead, we simply - # don't support automatic alignment with in-place arithmetic. - other_coords = getattr(other, "coords", None) - other_variable = getattr(other, "variable", other) - try: - with self.coords._merge_inplace(other_coords): - f(self.variable, other_variable) - except MergeError as exc: - raise MergeError( - "Automatic alignment is not supported for in-place operations.\n" - "Consider aligning the indices manually or using a not-in-place operation.\n" - "See https://github.com/pydata/xarray/issues/3910 for more explanations." - ) from exc - return self + return self._replace(variable, coords, name, indexes=indexes) - return func + def _inplace_binary_op(self, other, f: Callable): + if isinstance(other, groupby.GroupBy): + raise TypeError( + "in-place operations between a DataArray and " + "a grouped object are not permitted" + ) + # n.b. we can't align other to self (with other.reindex_like(self)) + # because `other` may be converted into floats, which would cause + # in-place arithmetic to fail unpredictably. Instead, we simply + # don't support automatic alignment with in-place arithmetic. + other_coords = getattr(other, "coords", None) + other_variable = getattr(other, "variable", other) + try: + with self.coords._merge_inplace(other_coords): + f(self.variable, other_variable) + except MergeError as exc: + raise MergeError( + "Automatic alignment is not supported for in-place operations.\n" + "Consider aligning the indices manually or using a not-in-place operation.\n" + "See https://github.com/pydata/xarray/issues/3910 for more explanations." + ) from exc + return self def _copy_attrs_from(self, other: Union["DataArray", Dataset, Variable]) -> None: self.attrs = other.attrs @@ -3047,7 +3140,6 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr `n` matches numpy's behavior and is different from pandas' first argument named `periods`. - Examples -------- >>> arr = xr.DataArray([5, 5, 6, 6], [[1, 2, 3, 4]], ["x"]) @@ -3087,7 +3179,7 @@ def shift( Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. - fill_value: scalar, optional + fill_value : scalar, optional Value to use for newly missing values **shifts_kwargs The keyword arguments form of ``shifts``. @@ -3099,13 +3191,12 @@ def shift( DataArray with the same coordinates and attributes but shifted data. - See also + See Also -------- roll Examples -------- - >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.shift(x=1) @@ -3149,13 +3240,12 @@ def roll( rolled : DataArray DataArray with the same attributes but rolled data and coordinates. - See also + See Also -------- shift Examples -------- - >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.roll(x=1) @@ -3195,14 +3285,13 @@ def dot( result : DataArray Array resulting from the dot product over all shared dimensions. - See also + See Also -------- dot numpy.tensordot Examples -------- - >>> da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) >>> da = xr.DataArray(da_vals, dims=["x", "y", "z"]) >>> dm_vals = np.arange(4) @@ -3265,7 +3354,6 @@ def sortby( Examples -------- - >>> da = xr.DataArray( ... np.random.rand(5), ... coords=[pd.date_range("1/1/2000", periods=5)], @@ -3338,7 +3426,6 @@ def quantile( Examples -------- - >>> da = xr.DataArray( ... data=[[0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]], ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, @@ -3410,7 +3497,6 @@ def rank( Examples -------- - >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.rank("x") @@ -3484,8 +3570,6 @@ def integrate( self, coord: Union[Hashable, Sequence[Hashable]] = None, datetime_unit: str = None, - *, - dim: Union[Hashable, Sequence[Hashable]] = None, ) -> "DataArray": """Integrate along the given coordinate using the trapezoidal rule. @@ -3495,21 +3579,20 @@ def integrate( Parameters ---------- - coord: hashable, or a sequence of hashable - Coordinate(s) used for the integration. - dim : hashable, or sequence of hashable + coord : hashable, or sequence of hashable Coordinate(s) used for the integration. - datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ 'ps', 'fs', 'as'}, optional + Specify the unit if a datetime coordinate is used. Returns ------- - integrated: DataArray + integrated : DataArray See also -------- Dataset.integrate - numpy.trapz: corresponding numpy function + numpy.trapz : corresponding numpy function Examples -------- @@ -3534,22 +3617,69 @@ def integrate( array([5.4, 6.6, 7.8]) Dimensions without coordinates: y """ - if dim is not None and coord is not None: - raise ValueError( - "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead." - ) + ds = self._to_temp_dataset().integrate(coord, datetime_unit) + return self._from_temp_dataset(ds) - if dim is not None and coord is None: - coord = dim - msg = ( - "The `dim` keyword argument to `DataArray.integrate` is " - "being replaced with `coord`, for consistency with " - "`Dataset.integrate`. Please pass `coord` instead." - " `dim` will be removed in version 0.19.0." - ) - warnings.warn(msg, FutureWarning, stacklevel=2) + def cumulative_integrate( + self, + coord: Union[Hashable, Sequence[Hashable]] = None, + datetime_unit: str = None, + ) -> "DataArray": + """Integrate cumulatively along the given coordinate using the trapezoidal rule. - ds = self._to_temp_dataset().integrate(coord, datetime_unit) + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + The first entry of the cumulative integral is always 0, in order to keep the + length of the dimension unchanged between input and output. + + Parameters + ---------- + coord : hashable, or sequence of hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional + Specify the unit if a datetime coordinate is used. + + Returns + ------- + integrated : DataArray + + See also + -------- + Dataset.cumulative_integrate + scipy.integrate.cumulative_trapezoid : corresponding scipy function + + Examples + -------- + + >>> da = xr.DataArray( + ... np.arange(12).reshape(4, 3), + ... dims=["x", "y"], + ... coords={"x": [0, 0.1, 1.1, 1.2]}, + ... ) + >>> da + + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + >>> + >>> da.cumulative_integrate("x") + + array([[0. , 0. , 0. ], + [0.15, 0.25, 0.35], + [4.65, 5.75, 6.85], + [5.4 , 6.6 , 7.8 ]]) + Coordinates: + * x (x) float64 0.0 0.1 1.1 1.2 + Dimensions without coordinates: y + """ + ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit) return self._from_temp_dataset(ds) def unify_chunks(self) -> "DataArray": @@ -3557,24 +3687,22 @@ def unify_chunks(self) -> "DataArray": Returns ------- - DataArray with consistent chunk sizes for all dask-array variables See Also -------- - dask.array.core.unify_chunks """ - ds = self._to_temp_dataset().unify_chunks() - return self._from_temp_dataset(ds) + + return unify_chunks(self)[0] def map_blocks( self, - func: "Callable[..., T_DSorDA]", + func: Callable[..., T_DSorDA], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> "T_DSorDA": + ) -> T_DSorDA: """ Apply a function to each block of this DataArray. @@ -3615,20 +3743,19 @@ def map_blocks( Notes ----- This function is designed for when ``func`` needs to manipulate a whole xarray object - subset to each block. In the more common case where ``func`` can work on numpy arrays, it is - recommended to use ``apply_ufunc``. + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. If none of the variables in this object is backed by dask arrays, calling this function is equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks xarray.DataArray.map_blocks Examples -------- - Calculate an anomaly from climatology using ``.groupby()``. Using ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, its indices, and its methods like ``.groupby()``. @@ -3664,7 +3791,7 @@ def map_blocks( ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array ... ) # doctest: +ELLIPSIS - dask.array + dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> Coordinates: * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 month (time) int64 dask.array @@ -3728,9 +3855,11 @@ def polyfit( polyfit_covariance The covariance matrix of the polynomial coefficient estimates (only included if `full=False` and `cov=True`) - See also + See Also -------- numpy.polyfit + numpy.polyval + xarray.polyval """ return self._to_temp_dataset().polyfit( dim, deg, skipna=skipna, rcond=rcond, w=w, full=full, cov=cov @@ -3845,19 +3974,17 @@ def pad( padded : DataArray DataArray with the padded coordinates and data. - See also + See Also -------- DataArray.shift, DataArray.roll, DataArray.bfill, DataArray.ffill, numpy.pad, dask.array.pad Notes ----- - By default when ``mode="constant"`` and ``constant_values=None``, integer types will be - promoted to ``float`` and padded with ``np.nan``. To avoid type promotion - specify ``constant_values=np.nan`` + For ``mode="constant"`` and ``constant_values=None``, integer types will be + promoted to ``float`` and padded with ``np.nan``. Examples -------- - >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) >>> arr.pad(x=(1, 2), constant_values=0) @@ -3880,16 +4007,16 @@ def pad( * x (x) float64 nan 0.0 1.0 nan * y (y) int64 10 20 30 40 z (x) float64 nan 100.0 200.0 nan - >>> da.pad(x=1, constant_values=np.nan) + + Careful, ``constant_values`` are coerced to the data type of the array which may + lead to a loss of precision: + + >>> da.pad(x=1, constant_values=1.23456789) - array([[-9223372036854775808, -9223372036854775808, -9223372036854775808, - -9223372036854775808], - [ 0, 1, 2, - 3], - [ 10, 11, 12, - 13], - [-9223372036854775808, -9223372036854775808, -9223372036854775808, - -9223372036854775808]]) + array([[ 1, 1, 1, 1], + [ 0, 1, 2, 3], + [10, 11, 12, 13], + [ 1, 1, 1, 1]]) Coordinates: * x (x) float64 nan 0.0 1.0 nan * y (y) int64 10 20 30 40 @@ -3949,13 +4076,12 @@ def idxmin( New `DataArray` object with `idxmin` applied to its data and the indicated dimension removed. - See also + See Also -------- Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin Examples -------- - >>> array = xr.DataArray( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) @@ -4046,13 +4172,12 @@ def idxmax( New `DataArray` object with `idxmax` applied to its data and the indicated dimension removed. - See also + See Also -------- Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax Examples -------- - >>> array = xr.DataArray( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) @@ -4140,7 +4265,7 @@ def argmin( ------- result : DataArray or dict of DataArray - See also + See Also -------- Variable.argmin, DataArray.idxmin @@ -4243,7 +4368,7 @@ def argmax( ------- result : DataArray or dict of DataArray - See also + See Also -------- Variable.argmax, DataArray.idxmax @@ -4306,10 +4431,187 @@ def argmax( else: return self._replace_maybe_drop_dims(result) + def query( + self, + queries: Mapping[Hashable, Any] = None, + parser: str = "pandas", + engine: str = None, + missing_dims: str = "raise", + **queries_kwargs: Any, + ) -> "DataArray": + """Return a new data array indexed along the specified + dimension(s), where the indexers are given as strings containing + Python expressions to be evaluated against the values in the array. + + Parameters + ---------- + queries : dict, optional + A dict with keys matching dimensions and values given by strings + containing Python expressions to be evaluated against the data variables + in the dataset. The expressions will be evaluated using the pandas + eval() function, and can contain any valid Python expressions but cannot + contain any Python statements. + parser : {"pandas", "python"}, default: "pandas" + The parser to use to construct the syntax tree from the expression. + The default of 'pandas' parses code slightly different than standard + Python. Alternatively, you can parse an expression using the 'python' + parser to retain strict Python semantics. + engine : {"python", "numexpr", None}, default: None + The engine used to evaluate the expression. Supported engines are: + - None: tries to use numexpr, falls back to python + - "numexpr": evaluates expressions using numexpr + - "python": performs operations as if you had eval’d in top level python + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + **queries_kwargs : {dim: query, ...}, optional + The keyword arguments form of ``queries``. + One of queries or queries_kwargs must be provided. + + Returns + ------- + obj : DataArray + A new DataArray with the same contents as this dataset, indexed by + the results of the appropriate queries. + + See Also + -------- + DataArray.isel + Dataset.query + pandas.eval + + Examples + -------- + >>> da = xr.DataArray(np.arange(0, 5, 1), dims="x", name="a") + >>> da + + array([0, 1, 2, 3, 4]) + Dimensions without coordinates: x + >>> da.query(x="a > 2") + + array([3, 4]) + Dimensions without coordinates: x + """ + + ds = self._to_dataset_whole(shallow_copy=True) + ds = ds.query( + queries=queries, + parser=parser, + engine=engine, + missing_dims=missing_dims, + **queries_kwargs, + ) + return ds[self.name] + + def curvefit( + self, + coords: Union[Union[str, "DataArray"], Iterable[Union[str, "DataArray"]]], + func: Callable[..., Any], + reduce_dims: Union[Hashable, Iterable[Hashable]] = None, + skipna: bool = True, + p0: Dict[str, Any] = None, + bounds: Dict[str, Any] = None, + param_names: Sequence[str] = None, + kwargs: Dict[str, Any] = None, + ): + """ + Curve fitting optimization for arbitrary functions. + + Wraps `scipy.optimize.curve_fit` with `apply_ufunc`. + + Parameters + ---------- + coords : hashable, DataArray, or sequence of DataArray or hashable + Independent coordinate(s) over which to perform the curve fitting. Must share + at least one dimension with the calling object. When fitting multi-dimensional + functions, supply `coords` as a sequence in the same order as arguments in + `func`. To fit along existing dimensions of the calling object, `coords` can + also be specified as a str or sequence of strs. + func : callable + User specified function in the form `f(x, *params)` which returns a numpy + array of length `len(x)`. `params` are the fittable parameters which are optimized + by scipy curve_fit. `x` can also be specified as a sequence containing multiple + coordinates, e.g. `f((x0, x1), *params)`. + reduce_dims : hashable or sequence of hashable + Additional dimension(s) over which to aggregate while fitting. For example, + calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will + aggregate all lat and lon points and fit the specified function along the + time dimension. + skipna : bool, optional + Whether to skip missing values when fitting. Default is True. + p0 : dict-like, optional + Optional dictionary of parameter names to initial guesses passed to the + `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will + be assigned initial values following the default scipy behavior. + bounds : dict-like, optional + Optional dictionary of parameter names to bounding values passed to the + `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest + will be unbounded following the default scipy behavior. + param_names : sequence of hashable, optional + Sequence of names for the fittable parameters of `func`. If not supplied, + this will be automatically determined by arguments of `func`. `param_names` + should be manually supplied when fitting a function that takes a variable + number of parameters. + **kwargs : optional + Additional keyword arguments to passed to scipy curve_fit. + + Returns + ------- + curvefit_results : Dataset + A single dataset which contains: + + [var]_curvefit_coefficients + The coefficients of the best fit. + [var]_curvefit_covariance + The covariance matrix of the coefficient estimates. + + See Also + -------- + DataArray.polyfit + scipy.optimize.curve_fit + """ + return self._to_temp_dataset().curvefit( + coords, + func, + reduce_dims=reduce_dims, + skipna=skipna, + p0=p0, + bounds=bounds, + param_names=param_names, + kwargs=kwargs, + ) + + def drop_duplicates( + self, + dim: Hashable, + keep: Union[ + str, + bool, + ] = "first", + ): + """Returns a new DataArray with duplicate dimension values removed. + + Parameters + ---------- + dim : dimension label, optional + keep : {"first", "last", False}, default: "first" + Determines which duplicates (if any) to keep. + - ``"first"`` : Drop duplicates except for the first occurrence. + - ``"last"`` : Drop duplicates except for the last occurrence. + - False : Drop all duplicates. + + Returns + ------- + DataArray + """ + if dim not in self.dims: + raise ValueError(f"'{dim}' not found in dimensions") + indexes = {dim: ~self.get_index(dim).duplicated(keep=keep)} + return self.isel(indexes) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor) - - -# priority most be higher than Variable to properly work with binary ufuncs -ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8376b4875f9..4bfc1ccbdf1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,10 +1,9 @@ import copy import datetime -import functools +import inspect import sys import warnings from collections import defaultdict -from distutils.version import LooseVersion from html import escape from numbers import Number from operator import methodcaller @@ -13,6 +12,7 @@ TYPE_CHECKING, Any, Callable, + Collection, DefaultDict, Dict, Hashable, @@ -52,11 +52,9 @@ weighted, ) from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align -from .common import ( - DataWithCoords, - ImplementsDatasetReduce, - _contains_datetime_like_objects, -) +from .arithmetic import DatasetArithmetic +from .common import DataWithCoords, _contains_datetime_like_objects +from .computation import unify_chunks from .coordinates import ( DatasetCoordinates, assert_coordinate_consistent, @@ -64,7 +62,10 @@ ) from .duck_array_ops import datetime_to_numeric from .indexes import ( + Index, Indexes, + PandasIndex, + PandasMultiIndex, default_indexes, isel_variable_and_index, propagate_indexes, @@ -85,7 +86,7 @@ Default, Frozen, HybridMappingProxy, - SortedKeysDict, + OrderedSet, _default, decode_numpy_dict_values, drop_dims_from_indexers, @@ -197,16 +198,15 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl for dim, size in zip(var.dims, var.shape): if dim in scalar_vars: raise ValueError( - "dimension %r already exists as a scalar variable" % dim + f"dimension {dim!r} already exists as a scalar variable" ) if dim not in dims: dims[dim] = size last_used[dim] = k elif dims[dim] != size: raise ValueError( - "conflicting sizes for dimension %r: " - "length %s on %r and length %s on %r" - % (dim, size, k, dims[dim], last_used[dim]) + f"conflicting sizes for dimension {dim!r}: " + f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}" ) return dims @@ -246,8 +246,7 @@ def merge_indexes( and var.dims != current_index_variable.dims ): raise ValueError( - "dimension mismatch between %r %s and %r %s" - % (dim, current_index_variable.dims, n, var.dims) + f"dimension mismatch between {dim!r} {current_index_variable.dims} and {n!r} {var.dims}" ) if current_index_variable is not None and append: @@ -257,7 +256,7 @@ def merge_indexes( codes.extend(current_index.codes) levels.extend(current_index.levels) else: - names.append("%s_level_0" % dim) + names.append(f"{dim}_level_0") cat = pd.Categorical(current_index.values, ordered=True) codes.append(cat.codes) levels.append(cat.categories) @@ -456,6 +455,63 @@ def as_dataset(obj: Any) -> "Dataset": return obj +def _get_func_args(func, param_names): + """Use `inspect.signature` to try accessing `func` args. Otherwise, ensure + they are provided by user. + """ + try: + func_args = inspect.signature(func).parameters + except ValueError: + func_args = {} + if not param_names: + raise ValueError( + "Unable to inspect `func` signature, and `param_names` was not provided." + ) + if param_names: + params = param_names + else: + params = list(func_args)[1:] + if any( + [(p.kind in [p.VAR_POSITIONAL, p.VAR_KEYWORD]) for p in func_args.values()] + ): + raise ValueError( + "`param_names` must be provided because `func` takes variable length arguments." + ) + return params, func_args + + +def _initialize_curvefit_params(params, p0, bounds, func_args): + """Set initial guess and bounds for curvefit. + Priority: 1) passed args 2) func signature 3) scipy defaults + """ + + def _initialize_feasible(lb, ub): + # Mimics functionality of scipy.optimize.minpack._initialize_feasible + lb_finite = np.isfinite(lb) + ub_finite = np.isfinite(ub) + p0 = np.nansum( + [ + 0.5 * (lb + ub) * int(lb_finite & ub_finite), + (lb + 1) * int(lb_finite & ~ub_finite), + (ub - 1) * int(~lb_finite & ub_finite), + ] + ) + return p0 + + param_defaults = {p: 1 for p in params} + bounds_defaults = {p: (-np.inf, np.inf) for p in params} + for p in params: + if p in func_args and func_args[p].default is not func_args[p].empty: + param_defaults[p] = func_args[p].default + if p in bounds: + bounds_defaults[p] = tuple(bounds[p]) + if param_defaults[p] < bounds[p][0] or param_defaults[p] > bounds[p][1]: + param_defaults[p] = _initialize_feasible(bounds[p][0], bounds[p][1]) + if p in p0: + param_defaults[p] = p0[p] + return param_defaults, bounds_defaults + + class DataVariables(Mapping[Hashable, "DataArray"]): __slots__ = ("_dataset",) @@ -489,7 +545,7 @@ def variables(self) -> Mapping[Hashable, Variable]: return Frozen({k: all_variables[k] for k in self}) def _ipython_key_completions_(self): - """Provide method for the key-autocompletions in IPython. """ + """Provide method for the key-autocompletions in IPython.""" return [ key for key in self._dataset._ipython_key_completions_() @@ -508,8 +564,19 @@ def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": raise TypeError("can only lookup dictionaries from Dataset.loc") return self.dataset.sel(key) + def __setitem__(self, key, value) -> None: + if not utils.is_dict_like(key): + raise TypeError( + "can only set locations defined by dictionaries from Dataset.loc." + f" Got: {key}" + ) + + # set new values + pos_indexers, _ = remap_label_indexers(self.dataset, key) + self.dataset[pos_indexers] = value + -class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): +class Dataset(DataWithCoords, DatasetArithmetic, Mapping): """A multi-dimensional, in memory, array database. A dataset resembles an in-memory representation of a NetCDF file, @@ -601,7 +668,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): ... ) >>> ds - Dimensions: (time: 3, x: 2, y: 2) + Dimensions: (x: 2, y: 2, time: 3) Coordinates: lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 lat (x, y) float64 42.25 42.21 42.63 42.59 @@ -638,7 +705,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _dims: Dict[Hashable, int] _encoding: Optional[Dict[Hashable, Any]] _close: Optional[Callable[[], None]] - _indexes: Optional[Dict[Hashable, pd.Index]] + _indexes: Optional[Dict[Hashable, Index]] _variables: Dict[Hashable, Variable] __slots__ = ( @@ -677,8 +744,7 @@ def __init__( both_data_and_coords = set(data_vars) & set(coords) if both_data_and_coords: raise ValueError( - "variables %r are found in both data_vars and coords" - % both_data_and_coords + f"variables {both_data_and_coords!r} are found in both data_vars and coords" ) if isinstance(coords, Dataset): @@ -751,7 +817,7 @@ def dims(self) -> Mapping[Hashable, int]: See `Dataset.sizes` and `DataArray.sizes` for consistently named properties. """ - return Frozen(SortedKeysDict(self._dims)) + return Frozen(self._dims) @property def sizes(self) -> Mapping[Hashable, int]: @@ -762,7 +828,7 @@ def sizes(self) -> Mapping[Hashable, int]: This is an alias for `Dataset.dims` provided for the benefit of consistency with `DataArray.sizes`. - See also + See Also -------- DataArray.sizes """ @@ -863,16 +929,25 @@ def __dask_scheduler__(self): return da.Array.__dask_scheduler__ def __dask_postcompute__(self): + return self._dask_postcompute, () + + def __dask_postpersist__(self): + return self._dask_postpersist, () + + def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset": import dask - info = [ - (True, k, v.__dask_postcompute__()) - if dask.is_dask_collection(v) - else (False, k, v) - for k, v in self._variables.items() - ] - args = ( - info, + variables = {} + results_iter = iter(results) + + for k, v in self._variables.items(): + if dask.is_dask_collection(v): + rebuild, args = v.__dask_postcompute__() + v = rebuild(next(results_iter), *args) + variables[k] = v + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -880,19 +955,50 @@ def __dask_postcompute__(self): self._encoding, self._close, ) - return self._dask_postcompute, args - def __dask_postpersist__(self): - import dask + def _dask_postpersist( + self, dsk: Mapping, *, rename: Mapping[str, str] = None + ) -> "Dataset": + from dask import is_dask_collection + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull - info = [ - (True, k, v.__dask_postpersist__()) - if dask.is_dask_collection(v) - else (False, k, v) - for k, v in self._variables.items() - ] - args = ( - info, + variables = {} + + for k, v in self._variables.items(): + if not is_dask_collection(v): + variables[k] = v + continue + + if isinstance(dsk, HighLevelGraph): + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph. + # Don't use dsk.cull(), as we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + layers = v.__dask_layers__() + if rename: + layers = [rename.get(k, k) for k in layers] + dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover + # At the moment of writing, this is only for forward compatibility. + # replace_name_in_key requires dask >= 2021.3. + from dask.base import flatten, replace_name_in_key + + keys = [ + replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__()) + ] + dsk2, _ = cull(dsk, keys) + else: + # __dask_postpersist__() was called by dask.optimize or dask.persist + dsk2, _ = cull(dsk, v.__dask_keys__()) + + rebuild, args = v.__dask_postpersist__() + # rename was added in dask 2021.3 + kwargs = {"rename": rename} if rename else {} + variables[k] = rebuild(dsk2, *args, **kwargs) + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -900,45 +1006,6 @@ def __dask_postpersist__(self): self._encoding, self._close, ) - return self._dask_postpersist, args - - @staticmethod - def _dask_postcompute(results, info, *args): - variables = {} - results2 = list(results[::-1]) - for is_dask, k, v in info: - if is_dask: - func, args2 = v - r = results2.pop() - result = func(r, *args2) - else: - result = v - variables[k] = result - - final = Dataset._construct_direct(variables, *args) - return final - - @staticmethod - def _dask_postpersist(dsk, info, *args): - variables = {} - # postpersist is called in both dask.optimize and dask.persist - # When persisting, we want to filter out unrelated keys for - # each Variable's task graph. - is_persist = len(dsk) == len(info) - for is_dask, k, v in info: - if is_dask: - func, args2 = v - if is_persist: - name = args2[1][0] - dsk2 = {k: v for k, v in dsk.items() if k[0] == name} - else: - dsk2 = dsk - result = func(dsk2, *args2) - else: - result = v - variables[k] = result - - return Dataset._construct_direct(variables, *args) def compute(self, **kwargs) -> "Dataset": """Manually trigger loading and/or computation of this dataset's data @@ -1032,7 +1099,7 @@ def _replace( coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Any, pd.Index], None, Default] = _default, + indexes: Union[Dict[Any, Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, ) -> "Dataset": @@ -1081,7 +1148,7 @@ def _replace_with_new_dims( variables: Dict[Hashable, Variable], coord_names: set = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Hashable, pd.Index], None, Default] = _default, + indexes: Union[Dict[Hashable, Index], None, Default] = _default, inplace: bool = False, ) -> "Dataset": """Replace variables with recalculated dimensions.""" @@ -1109,22 +1176,23 @@ def _replace_vars_and_dims( variables, coord_names, dims, attrs, indexes=None, inplace=inplace ) - def _overwrite_indexes(self, indexes: Mapping[Any, pd.Index]) -> "Dataset": + def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset": if not indexes: return self variables = self._variables.copy() - new_indexes = dict(self.indexes) + new_indexes = dict(self.xindexes) for name, idx in indexes.items(): - variables[name] = IndexVariable(name, idx) + variables[name] = IndexVariable(name, idx.to_pandas_index()) new_indexes[name] = idx obj = self._replace(variables, indexes=new_indexes) # switch from dimension to level names, if necessary dim_names: Dict[Hashable, str] = {} for dim, idx in indexes.items(): - if not isinstance(idx, pd.MultiIndex) and idx.name != dim: - dim_names[dim] = idx.name + pd_idx = idx.to_pandas_index() + if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: + dim_names[dim] = pd_idx.name if dim_names: obj = obj.rename(dim_names) return obj @@ -1159,7 +1227,6 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": Examples -------- - Shallow copy versus deep copy >>> da = xr.DataArray(np.random.randn(2, 3)) @@ -1255,15 +1322,29 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": return self._replace(variables, attrs=attrs) + def as_numpy(self: "Dataset") -> "Dataset": + """ + Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. + + See also + -------- + DataArray.as_numpy + DataArray.to_numpy : Returns only the data as a numpy.ndarray object. + """ + numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} + return self._replace(variables=numpy_variables) + @property def _level_coords(self) -> Dict[str, Hashable]: """Return a mapping of all MultiIndex levels and their corresponding coordinate name. """ level_coords: Dict[str, Hashable] = {} - for name, index in self.indexes.items(): - if isinstance(index, pd.MultiIndex): - level_names = index.names + for name, index in self.xindexes.items(): + # TODO: benbovy - flexible indexes: update when MultIndex has its own xarray class. + pd_index = index.to_pandas_index() + if isinstance(pd_index, pd.MultiIndex): + level_names = pd_index.names (dim,) = self.variables[name].dims level_coords.update({lname: dim for lname in level_names}) return level_coords @@ -1274,7 +1355,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": """ variables: Dict[Hashable, Variable] = {} coord_names = set() - indexes: Dict[Hashable, pd.Index] = {} + indexes: Dict[Hashable, Index] = {} for name in names: try: @@ -1287,9 +1368,9 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if ref_name in self._coord_names or ref_name in self.dims: coord_names.add(var_name) if (var_name,) == var.dims: - indexes[var_name] = var.to_index() + indexes[var_name] = var._to_xindex() - needed_dims: Set[Hashable] = set() + needed_dims: OrderedSet[Hashable] = OrderedSet() for v in variables.values(): needed_dims.update(v.dims) @@ -1303,8 +1384,8 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": if set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) - if k in self.indexes: - indexes[k] = self.indexes[k] + if k in self.xindexes: + indexes[k] = self.xindexes[k] return self._replace(variables, coord_names, dims, indexes=indexes) @@ -1396,11 +1477,11 @@ def loc(self) -> _LocIndexer: # FIXME https://github.com/python/mypy/issues/7328 @overload - def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore + def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore[misc] ... @overload - def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore + def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore[misc] ... @overload @@ -1419,38 +1500,142 @@ def __getitem__(self, key): if hashable(key): return self._construct_dataarray(key) else: - return self._copy_listed(np.asarray(key)) + return self._copy_listed(key) - def __setitem__(self, key: Hashable, value) -> None: + def __setitem__(self, key: Union[Hashable, List[Hashable], Mapping], value) -> None: """Add an array to this dataset. + Multiple arrays can be added at the same time, in which case each of + the following operations is applied to the respective value. + + If key is a dictionary, update all variables in the dataset + one by one with the given value at the given location. + If the given value is also a dataset, select corresponding variables + in the given value and in the dataset to be changed. If value is a `DataArray`, call its `select_vars()` method, rename it to `key` and merge the contents of the resulting dataset into this dataset. - If value is an `Variable` object (or tuple of form + If value is a `Variable` object (or tuple of form ``(dims, data[, attrs])``), add it to this dataset as a new variable. """ if utils.is_dict_like(key): - raise NotImplementedError( - "cannot yet use a dictionary as a key to set Dataset values" + # check for consistency and convert value to dataset + value = self._setitem_check(key, value) + # loop over dataset variables and set new values + processed = [] + for name, var in self.items(): + try: + var[key] = value[name] + processed.append(name) + except Exception as e: + if processed: + raise RuntimeError( + "An error occured while setting values of the" + f" variable '{name}'. The following variables have" + f" been successfully updated:\n{processed}" + ) from e + else: + raise e + + elif isinstance(key, list): + if len(key) == 0: + raise ValueError("Empty list of variables to be set") + if len(key) == 1: + self.update({key[0]: value}) + else: + if len(key) != len(value): + raise ValueError( + f"Different lengths of variables to be set " + f"({len(key)}) and data used as input for " + f"setting ({len(value)})" + ) + if isinstance(value, Dataset): + self.update(dict(zip(key, value.data_vars.values()))) + elif isinstance(value, xr.DataArray): + raise ValueError("Cannot assign single DataArray to multiple keys") + else: + self.update(dict(zip(key, value))) + + else: + self.update({key: value}) + + def _setitem_check(self, key, value): + """Consistency check for __setitem__ + + When assigning values to a subset of a Dataset, do consistency check beforehand + to avoid leaving the dataset in a partially updated state when an error occurs. + """ + from .dataarray import DataArray + + if isinstance(value, Dataset): + missing_vars = [ + name for name in value.data_vars if name not in self.data_vars + ] + if missing_vars: + raise ValueError( + f"Variables {missing_vars} in new values" + f" not available in original dataset:\n{self}" + ) + elif not any([isinstance(value, t) for t in [DataArray, Number, str]]): + raise TypeError( + "Dataset assignment only accepts DataArrays, Datasets, and scalars." ) - self.update({key: value}) + new_value = xr.Dataset() + for name, var in self.items(): + # test indexing + try: + var_k = var[key] + except Exception as e: + raise ValueError( + f"Variable '{name}': indexer {key} not available" + ) from e + + if isinstance(value, Dataset): + val = value[name] + else: + val = value + + if isinstance(val, DataArray): + # check consistency of dimensions + for dim in val.dims: + if dim not in var_k.dims: + raise KeyError( + f"Variable '{name}': dimension '{dim}' appears in new values " + f"but not in the indexed original data" + ) + dims = tuple([dim for dim in var_k.dims if dim in val.dims]) + if dims != val.dims: + raise ValueError( + f"Variable '{name}': dimension order differs between" + f" original and new data:\n{dims}\nvs.\n{val.dims}" + ) + else: + val = np.array(val) + + # type conversion + new_value[name] = val.astype(var_k.dtype, copy=False) + + # check consistency of dimension sizes and dimension coordinates + if isinstance(value, DataArray) or isinstance(value, Dataset): + xr.align(self[key], value, join="exact", copy=False) + + return new_value def __delitem__(self, key: Hashable) -> None: """Remove a variable from this dataset.""" del self._variables[key] self._coord_names.discard(key) - if key in self.indexes: + if key in self.xindexes: assert self._indexes is not None del self._indexes[key] self._dims = calculate_dimensions(self._variables) # mutable objects should not be hashable # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore + __hash__ = None # type: ignore[assignment] def _all_compat(self, other: "Dataset", compat_str: str) -> bool: """Helper function for equals and identical""" @@ -1520,7 +1705,21 @@ def identical(self, other: "Dataset") -> bool: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing""" + """Mapping of pandas.Index objects used for label based indexing. + + Raises an error if this Dataset has indexes that cannot be coerced + to pandas.Index objects. + + See Also + -------- + Dataset.xindexes + + """ + return Indexes({k: idx.to_pandas_index() for k, idx in self.xindexes.items()}) + + @property + def xindexes(self) -> Indexes: + """Mapping of xarray Index objects used for label based indexing.""" if self._indexes is None: self._indexes = default_indexes(self._variables, self._dims) return Indexes(self._indexes) @@ -1549,7 +1748,7 @@ def set_coords(self, names: "Union[Hashable, Iterable[Hashable]]") -> "Dataset": ------- Dataset - See also + See Also -------- Dataset.swap_dims """ @@ -1597,7 +1796,7 @@ def reset_coords( bad_coords = set(names) & set(self.dims) if bad_coords: raise ValueError( - "cannot remove index coordinates with reset_coords: %s" % bad_coords + f"cannot remove index coordinates with reset_coords: {bad_coords}" ) obj = self.copy() obj._coord_names.difference_update(names) @@ -1719,15 +1918,25 @@ def to_zarr( group: str = None, encoding: Mapping = None, compute: bool = True, - consolidated: bool = False, + consolidated: Optional[bool] = None, append_dim: Hashable = None, region: Mapping[str, slice] = None, + safe_chunks: bool = True, ) -> "ZarrStore": """Write dataset contents to a zarr group. - .. note:: Experimental - The Zarr backend is new and experimental. Please report any - unexpected behavior via github issues. + Zarr chunks are determined in the following way: + + - From the ``chunks`` attribute in each variable's ``encoding`` + - If the variable is a Dask array, from the dask chunks + - If neither Dask chunks nor encoding chunks are present, chunks will + be determined automatically by Zarr + - If both Dask chunks and encoding chunks are present, encoding chunks + will be used, provided that there is a many-to-one relationship between + encoding chunks and dask chunks (i.e. Dask chunks are bigger than and + evenly divide encoding chunks); otherwise raise a ``ValueError``. + This restriction ensures that no synchronization / locks are required + when writing. To disable this restriction, use ``safe_chunks=False``. Parameters ---------- @@ -1736,13 +1945,14 @@ def to_zarr( chunk_store : MutableMapping, str or Path, optional Store or path to directory in file system only for Zarr array chunks. Requires zarr-python v2.4.0 or later. - mode : {"w", "w-", "a", None}, optional + mode : {"w", "w-", "a", "r+", None}, optional Persistence mode: "w" means create (overwrite if exists); "w-" means create (fail if exists); - "a" means override existing variables (create if does not exist). - If ``append_dim`` is set, ``mode`` can be omitted as it is - internally set to ``"a"``. Otherwise, ``mode`` will default to - `w-` if not set. + "a" means override existing variables (create if does not exist); + "r+" means modify existing array *values* only (raise an error if + any metadata or shapes would change). + The default mode is "a" if ``append_dim`` is set. Otherwise, it is + "r+" if ``region`` is set and ``w-`` otherwise. synchronizer : object, optional Zarr array synchronizer. group : str, optional @@ -1751,17 +1961,20 @@ def to_zarr( Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` - compute: bool, optional + compute : bool, optional If True write array data immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed to write array data later. Metadata is always updated eagerly. - consolidated: bool, optional + consolidated : bool, optional If True, apply zarr's `consolidate_metadata` function to the store - after writing metadata. - append_dim: hashable, optional + after writing metadata and read existing stores with consolidated + metadata; if False, do not. The default (`consolidated=None`) means + write consolidated metadata and attempt to read consolidated + metadata for existing stores (falling back to non-consolidated). + append_dim : hashable, optional If set, the dimension along which the data will be appended. All other dimensions on overriden variables must remain the same size. - region: dict, optional + region : dict, optional Optional mapping from dimension names to integer slices along dataset dimensions to indicate the region of existing zarr array(s) in which to write this dataset's data. For example, @@ -1779,6 +1992,13 @@ def to_zarr( in with ``region``, use a separate call to ``to_zarr()`` with ``compute=False``. See "Appending to existing Zarr stores" in the reference documentation for full details. + safe_chunks : bool, optional + If True, only allow writes to when there is a many-to-one relationship + between Zarr chunks (specified in encoding) and Dask chunks. + Set False to override this restriction; however, data may become corrupted + if Zarr arrays are written in parallel. This option may be useful in combination + with ``compute=False`` to initialize a Zarr from an existing + Dataset with aribtrary chunk structure. References ---------- @@ -1792,6 +2012,15 @@ def to_zarr( If a DataArray is a dask array, it is written with those chunks. If not other chunks are found, Zarr uses its own heuristics to choose automatic chunk sizes. + + encoding: + The encoding attribute (if exists) of the DataArray(s) will be + used. Override any existing encodings by providing the ``encoding`` kwarg. + + See Also + -------- + :ref:`io.zarr` + The I/O user guide, with more details and examples. """ from ..backends.api import to_zarr @@ -1810,6 +2039,7 @@ def to_zarr( consolidated=consolidated, append_dim=append_dim, region=region, + safe_chunks=safe_chunks, ) def __repr__(self) -> str: @@ -1832,7 +2062,7 @@ def info(self, buf=None) -> None: See Also -------- pandas.DataFrame.assign - ncdump: netCDF's ncdump + ncdump : netCDF's ncdump """ if buf is None: # pragma: no cover buf = sys.stdout @@ -1870,14 +2100,14 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: "This can be fixed by calling unify_chunks()." ) chunks[dim] = c - return Frozen(SortedKeysDict(chunks)) + return Frozen(chunks) def chunk( self, chunks: Union[ - Number, + int, str, - Mapping[Hashable, Union[None, Number, str, Tuple[Number, ...]]], + Mapping[Hashable, Union[None, int, str, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, @@ -1918,13 +2148,13 @@ def chunk( ) chunks = {} - if isinstance(chunks, (Number, str)): + if isinstance(chunks, (Number, str, int)): chunks = dict.fromkeys(self.dims, chunks) bad_dims = chunks.keys() - self.dims.keys() if bad_dims: raise ValueError( - "some chunks keys are not dimensions on this " "object: %s" % bad_dims + f"some chunks keys are not dimensions on this object: {bad_dims}" ) variables = { @@ -1962,7 +2192,9 @@ def _validate_indexers( v = np.asarray(v) if v.dtype.kind in "US": - index = self.indexes[k] + # TODO: benbovy - flexible indexes + # update when CFTimeIndex has its own xarray index class + index = self.xindexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): @@ -1986,12 +2218,12 @@ def _validate_interp_indexers( else: yield k, v elif isinstance(v, int): - yield k, Variable((), v) + yield k, Variable((), v, attrs=self.coords[k].attrs) elif isinstance(v, np.ndarray): if v.ndim == 0: - yield k, Variable((), v) + yield k, Variable((), v, attrs=self.coords[k].attrs) elif v.ndim == 1: - yield k, IndexVariable((k,), v) + yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs) else: raise AssertionError() # Already tested by _validate_indexers else: @@ -2111,7 +2343,7 @@ def isel( continue if indexes and var_name in indexes: if var_value.ndim == 1: - indexes[var_name] = var_value.to_index() + indexes[var_name] = var_value._to_xindex() else: del indexes[var_name] variables[var_name] = var_value @@ -2139,16 +2371,16 @@ def _isel_fancy( indexers_list = list(self._validate_indexers(indexers, missing_dims)) variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, pd.Index] = {} + indexes: Dict[Hashable, Index] = {} for name, var in self.variables.items(): var_indexers = {k: v for k, v in indexers_list if k in var.dims} if drop and name in var_indexers: continue # drop this variable - if name in self.indexes: + if name in self.xindexes: new_var, new_index = isel_variable_and_index( - name, var, self.indexes[name], var_indexers + name, var, self.xindexes[name], var_indexers ) if new_index is not None: indexes[name] = new_index @@ -2232,7 +2464,6 @@ def sel( in this dataset, unless vectorized indexing was triggered by using an array indexer, in which case the data will be a copy. - See Also -------- Dataset.isel @@ -2242,6 +2473,10 @@ def sel( pos_indexers, new_indexes = remap_label_indexers( self, indexers=indexers, method=method, tolerance=tolerance ) + # TODO: benbovy - flexible indexes: also use variables returned by Index.query + # (temporary dirty fix). + new_indexes = {k: v[0] for k, v in new_indexes.items()} + result = self.isel(indexers=pos_indexers, drop=drop) return result._overwrite_indexes(new_indexes) @@ -2263,7 +2498,6 @@ def head( The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. - See Also -------- Dataset.tail @@ -2282,12 +2516,12 @@ def head( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) indexers_slices = {k: slice(val) for k, val in indexers.items()} return self.isel(indexers_slices) @@ -2310,7 +2544,6 @@ def tail( The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. - See Also -------- Dataset.head @@ -2329,12 +2562,12 @@ def tail( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) indexers_slices = { k: slice(-val, None) if val != 0 else slice(val) @@ -2360,7 +2593,6 @@ def thin( The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. - See Also -------- Dataset.head @@ -2380,12 +2612,12 @@ def thin( if not isinstance(v, int): raise TypeError( "expected integer type indexer for " - "dimension %r, found %r" % (k, type(v)) + f"dimension {k!r}, found {type(v)!r}" ) elif v < 0: raise ValueError( "expected positive integer as indexer " - "for dimension %r, found %s" % (k, v) + f"for dimension {k!r}, found {v}" ) elif v == 0: raise ValueError("step cannot be zero") @@ -2536,11 +2768,8 @@ def reindex( Examples -------- - Create a dataset with some fictional data. - >>> import xarray as xr - >>> import pandas as pd >>> x = xr.Dataset( ... { ... "temperature": ("station", 20 * np.random.rand(4)), @@ -2707,12 +2936,12 @@ def _reindex( bad_dims = [d for d in indexers if d not in self.dims] if bad_dims: - raise ValueError("invalid reindex dimensions: %s" % bad_dims) + raise ValueError(f"invalid reindex dimensions: {bad_dims}") variables, indexes = alignment.reindex_variables( self.variables, self.sizes, - self.indexes, + self.xindexes, indexers, method, tolerance, @@ -2730,6 +2959,7 @@ def interp( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", **coords_kwargs: Any, ) -> "Dataset": """Multidimensional interpolation of Dataset. @@ -2750,10 +2980,13 @@ def interp( in any order and they are sorted first. If True, interpolated coordinates are assumed to be an array of monotonically increasing values. - kwargs: dict, optional + kwargs : dict, optional Additional keyword arguments passed to scipy's interpolator. Valid options and their behavior depend on if 1-dimensional or multi-dimensional interpolation is used. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. **coords_kwargs : {dim: coordinate, ...}, optional The keyword arguments form of ``coords``. One of coords or coords_kwargs must be provided. @@ -2890,34 +3123,81 @@ def _validate_interp_indexer(x, new_x): ) return x, new_x + validated_indexers = { + k: _validate_interp_indexer(maybe_variable(obj, k), v) + for k, v in indexers.items() + } + + # optimization: subset to coordinate range of the target index + if method in ["linear", "nearest"]: + for k, v in validated_indexers.items(): + obj, newidx = missing._localize(obj, {k: v}) + validated_indexers[k] = newidx[k] + + # optimization: create dask coordinate arrays once per Dataset + # rather than once per Variable when dask.array.unify_chunks is called later + # GH4739 + if obj.__dask_graph__(): + dask_indexers = { + k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk()) + for k, (index, dest) in validated_indexers.items() + } + variables: Dict[Hashable, Variable] = {} + to_reindex: Dict[Hashable, Variable] = {} for name, var in obj._variables.items(): if name in indexers: continue - if var.dtype.kind in "uifc": - var_indexers = { - k: _validate_interp_indexer(maybe_variable(obj, k), v) - for k, v in indexers.items() - if k in var.dims - } + if is_duck_dask_array(var.data): + use_indexers = dask_indexers + else: + use_indexers = validated_indexers + + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # For normal number types do the interpolation: + var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims} variables[name] = missing.interp(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims): + # For types that we do not understand do stepwise + # interpolation to avoid modifying the elements. + # Use reindex_variables instead because it supports + # booleans and objects and retains the dtype but inside + # this loop there might be some duplicate code that slows it + # down, therefore collect these signals and run it later: + to_reindex[name] = var elif all(d not in indexers for d in var.dims): - # keep unrelated object array + # For anything else we can only keep variables if they + # are not dependent on any coords that are being + # interpolated along: variables[name] = var + if to_reindex: + # Reindex variables: + variables_reindex = alignment.reindex_variables( + variables=to_reindex, + sizes=obj.sizes, + indexes=obj.xindexes, + indexers={k: v[-1] for k, v in validated_indexers.items()}, + method=method_non_numeric, + )[0] + variables.update(variables_reindex) + + # Get the coords that also exist in the variables: coord_names = obj._coord_names & variables.keys() - indexes = {k: v for k, v in obj.indexes.items() if k not in indexers} + # Get the indexes that are not being interpolated along: + indexes = {k: v for k, v in obj.xindexes.items() if k not in indexers} selected = self._replace_with_new_dims( variables.copy(), coord_names, indexes=indexes ) - # attach indexer as coordinate + # Attach indexer as coordinate variables.update(indexers) for k, v in indexers.items(): assert isinstance(v, Variable) if v.dims == (k,): - indexes[k] = v.to_index() + indexes[k] = v._to_xindex() # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) @@ -2933,6 +3213,7 @@ def interp_like( method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, + method_non_numeric: str = "nearest", ) -> "Dataset": """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -2952,8 +3233,11 @@ def interp_like( in any order and they are sorted first. If True, interpolated coordinates are assumed to be an array of monotonically increasing values. - kwargs: dict, optional + kwargs : dict, optional Additional keyword passed to scipy's interpolator. + method_non_numeric : {"nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method for non-numeric types. Passed on to :py:meth:`Dataset.reindex`. + ``"nearest"`` is used by default. Returns ------- @@ -2989,7 +3273,13 @@ def interp_like( # We do not support interpolation along object coordinate. # reindex instead. ds = self.reindex(object_coords) - return ds.interp(numeric_coords, method, assume_sorted, kwargs) + return ds.interp( + coords=numeric_coords, + method=method, + assume_sorted=assume_sorted, + kwargs=kwargs, + method_non_numeric=method_non_numeric, + ) # Helper methods for rename() def _rename_vars(self, name_dict, dims_dict): @@ -3010,6 +3300,7 @@ def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} def _rename_indexes(self, name_dict, dims_set): + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645 if self._indexes is None: return None indexes = {} @@ -3019,10 +3310,11 @@ def _rename_indexes(self, name_dict, dims_set): continue if isinstance(v, pd.MultiIndex): new_names = [name_dict.get(k, k) for k in v.names] - index = v.rename(names=new_names) + indexes[new_name] = PandasMultiIndex( + v.rename(names=new_names), new_name + ) else: - index = v.rename(new_name) - indexes[new_name] = index + indexes[new_name] = PandasIndex(v.rename(new_name), new_name) return indexes def _rename_all(self, name_dict, dims_dict): @@ -3063,8 +3355,8 @@ def rename( for k in name_dict.keys(): if k not in self and k not in self.dims: raise ValueError( - "cannot rename %r because it is not a " - "variable or dimension in this dataset" % k + f"cannot rename {k!r} because it is not a " + "variable or dimension in this dataset" ) variables, coord_names, dims, indexes = self._rename_all( @@ -3104,8 +3396,8 @@ def rename_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - "cannot rename %r because it is not a " - "dimension in this dataset" % k + f"cannot rename {k!r} because it is not a " + "dimension in this dataset" ) if v in self.dims or v in self: raise ValueError( @@ -3148,8 +3440,8 @@ def rename_vars( for k in name_dict: if k not in self: raise ValueError( - "cannot rename %r because it is not a " - "variable or coordinate in this dataset" % k + f"cannot rename {k!r} because it is not a " + "variable or coordinate in this dataset" ) variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict={} @@ -3166,8 +3458,7 @@ def swap_dims( dims_dict : dict-like Dictionary whose keys are current dimension names and whose values are new names. - - **dim_kwargs : {existing_dim: new_dim, ...}, optional + **dims_kwargs : {existing_dim: new_dim, ...}, optional The keyword arguments form of ``dims_dict``. One of dims_dict or dims_kwargs must be provided. @@ -3215,7 +3506,6 @@ def swap_dims( See Also -------- - Dataset.rename DataArray.swap_dims """ @@ -3226,13 +3516,13 @@ def swap_dims( for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( - "cannot swap from dimension %r because it is " - "not an existing dimension" % k + f"cannot swap from dimension {k!r} because it is " + "not an existing dimension" ) if v in self.variables and self.variables[v].dims != (k,): raise ValueError( - "replacement dimension %r is not a 1D " - "variable along the old dimension %r" % (v, k) + f"replacement dimension {v!r} is not a 1D " + f"variable along the old dimension {k!r}" ) result_dims = {dims_dict.get(dim, dim) for dim in self.dims} @@ -3241,19 +3531,22 @@ def swap_dims( coord_names.update({dim for dim in dims_dict.values() if dim in self.variables}) variables: Dict[Hashable, Variable] = {} - indexes: Dict[Hashable, pd.Index] = {} + indexes: Dict[Hashable, Index] = {} for k, v in self.variables.items(): dims = tuple(dims_dict.get(dim, dim) for dim in v.dims) if k in result_dims: var = v.to_index_variable() - if k in self.indexes: - indexes[k] = self.indexes[k] + if k in self.xindexes: + indexes[k] = self.xindexes[k] else: new_index = var.to_index() if new_index.nlevels == 1: # make sure index name matches dimension name new_index = new_index.rename(k) - indexes[k] = new_index + if isinstance(new_index, pd.MultiIndex): + indexes[k] = PandasMultiIndex(new_index, k) + else: + indexes[k] = PandasIndex(new_index, k) else: var = v.to_base_variable() var.dims = dims @@ -3516,15 +3809,17 @@ def reorder_levels( """ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() - indexes = dict(self.indexes) + indexes = dict(self.xindexes) for dim, order in dim_order.items(): coord = self._variables[dim] - index = self.indexes[dim] + # TODO: benbovy - flexible indexes: update when MultiIndex + # has its own class inherited from xarray.Index + index = self.xindexes[dim].to_pandas_index() if not isinstance(index, pd.MultiIndex): raise ValueError(f"coordinate {dim} has no MultiIndex") new_index = index.reorder_levels(order) variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = new_index + indexes[dim] = PandasMultiIndex(new_index, dim) return self._replace(variables, indexes=indexes) @@ -3551,8 +3846,8 @@ def _stack_once(self, dims, new_dim): coord_names = set(self._coord_names) - set(dims) | {new_dim} - indexes = {k: v for k, v in self.indexes.items() if k not in dims} - indexes[new_dim] = idx + indexes = {k: v for k, v in self.xindexes.items() if k not in dims} + indexes[new_dim] = PandasMultiIndex(idx, new_dim) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -3586,7 +3881,7 @@ def stack( stacked : Dataset Dataset with stacked data. - See also + See Also -------- Dataset.unstack """ @@ -3599,8 +3894,8 @@ def stack( def to_stacked_array( self, new_dim: Hashable, - sample_dims: Sequence[Hashable], - variable_dim: str = "variable", + sample_dims: Collection, + variable_dim: Hashable = "variable", name: Hashable = None, ) -> "DataArray": """Combine variables of differing dimensionality into a DataArray @@ -3613,14 +3908,15 @@ def to_stacked_array( ---------- new_dim : hashable Name of the new stacked coordinate - sample_dims : sequence of hashable - Dimensions that **will not** be stacked. Each array in the dataset - must share these dimensions. For machine learning applications, - these define the dimensions over which samples are drawn. - variable_dim : str, optional + sample_dims : Collection of hashables + List of dimensions that **will not** be stacked. Each array in the + dataset must share these dimensions. For machine learning + applications, these define the dimensions over which samples are + drawn. + variable_dim : hashable, optional Name of the level in the stacked coordinate which corresponds to the variables. - name : str, optional + name : hashable, optional Name of the new data array. Returns @@ -3704,7 +4000,9 @@ def ensure_stackable(val): # coerce the levels of the MultiIndex to have the same type as the # input dimensions. This code is messy, so it might be better to just # input a dummy value for the singleton dimension. - idx = data_array.indexes[new_dim] + # TODO: benbovy - flexible indexes: update when MultIndex has its own + # class inheriting from xarray.Index + idx = data_array.xindexes[new_dim].to_pandas_index() levels = [idx.levels[0]] + [ level.astype(self[level.name].dtype) for level in idx.levels[1:] ] @@ -3721,7 +4019,7 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": index = remove_unused_levels_categories(index) variables: Dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.indexes.items() if k != dim} + indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in self.variables.items(): if name != dim: @@ -3738,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": variables[name] = var for name, lev in zip(index.names, index.levels): - variables[name] = IndexVariable(name, lev) - indexes[name] = lev + idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + variables[name] = idx_vars[name] + indexes[name] = idx coord_names = set(self._coord_names) - {dim} | set(index.names) @@ -3766,7 +4065,7 @@ def _unstack_full_reindex( new_dim_sizes = [lev.size for lev in index.levels] variables: Dict[Hashable, Variable] = {} - indexes = {k: v for k, v in self.indexes.items() if k != dim} + indexes = {k: v for k, v in self.xindexes.items() if k != dim} for name, var in obj.variables.items(): if name != dim: @@ -3777,8 +4076,9 @@ def _unstack_full_reindex( variables[name] = var for name, lev in zip(new_dim_names, index.levels): - variables[name] = IndexVariable(name, lev) - indexes[name] = lev + idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + variables[name] = idx_vars[name] + indexes[name] = idx coord_names = set(self._coord_names) - {dim} | set(new_dim_names) @@ -3815,7 +4115,7 @@ def unstack( unstacked : Dataset Dataset with unstacked data. - See also + See Also -------- Dataset.stack """ @@ -3832,7 +4132,7 @@ def unstack( missing_dims = [d for d in dims if d not in self.dims] if missing_dims: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dims + f"Dataset does not contain the dimensions: {missing_dims}" ) non_multi_dims = [ @@ -3841,7 +4141,7 @@ def unstack( if non_multi_dims: raise ValueError( "cannot unstack dimensions that do not " - "have a MultiIndex: %s" % non_multi_dims + f"have a MultiIndex: {non_multi_dims}" ) result = self.copy(deep=False) @@ -3860,8 +4160,6 @@ def unstack( for v in self.variables.values() ) or sparse - # numpy full_like only added `shape` in 1.17 - or LooseVersion(np.__version__) < LooseVersion("1.17") # Until https://github.com/pydata/xarray/pull/4751 is resolved, # we check explicitly whether it's a numpy array. Once that is # resolved, explicitly exclude pint arrays. @@ -3883,6 +4181,9 @@ def unstack( def update(self, other: "CoercibleMapping") -> "Dataset": """Update this dataset's variables with those from another dataset. + Just like :py:meth:`dict.update` this is a in-place operation. + For a non-inplace version, see :py:meth:`Dataset.merge`. + Parameters ---------- other : Dataset or mapping @@ -3894,17 +4195,24 @@ def update(self, other: "CoercibleMapping") -> "Dataset": - mapping {var name: (dimension name, array-like)} - mapping {var name: (tuple of dimension names, array-like)} - Returns ------- updated : Dataset - Updated dataset. + Updated dataset. Note that since the update is in-place this is the input + dataset. + + It is deprecated since version 0.17 and scheduled to be removed in 0.21. Raises ------ ValueError If any dimensions would have inconsistent sizes in the updated dataset. + + See Also + -------- + Dataset.assign + Dataset.merge """ merge_result = dataset_update_method(self, other) return self._replace(inplace=True, **merge_result._asdict()) @@ -3916,6 +4224,7 @@ def merge( compat: str = "no_conflicts", join: str = "outer", fill_value: Any = dtypes.NA, + combine_attrs: str = "override", ) -> "Dataset": """Merge the arrays of two datasets into a single dataset. @@ -3944,7 +4253,6 @@ def merge( - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. - join : {"outer", "inner", "left", "right", "exact"}, optional Method for joining ``self`` and ``other`` along shared dimensions: @@ -3956,6 +4264,18 @@ def merge( fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps variable names (including coordinates) to fill values. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. Returns ------- @@ -3966,6 +4286,10 @@ def merge( ------ MergeError If any variables conflict (see ``compat``). + + See Also + -------- + Dataset.update """ other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( @@ -3975,6 +4299,7 @@ def merge( compat=compat, join=join, fill_value=fill_value, + combine_attrs=combine_attrs, ) return self._replace(**merge_result._asdict()) @@ -4019,7 +4344,7 @@ def drop_vars( variables = {k: v for k, v in self._variables.items() if k not in names} coord_names = {k for k in self._coord_names if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k not in names} + indexes = {k: v for k, v in self.xindexes.items() if k not in names} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) @@ -4137,7 +4462,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): try: index = self.get_index(dim) except KeyError: - raise ValueError("dimension %r does not have coordinate labels" % dim) + raise ValueError(f"dimension {dim!r} does not have coordinate labels") new_index = index.drop(labels_for_dim, errors=errors) ds = ds.loc[{dim: new_index}] return ds @@ -4216,21 +4541,16 @@ def drop_dims( ---------- drop_dims : hashable or iterable of hashable Dimension or dimensions to drop. - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if any of the + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the dimensions passed are not in the dataset. If 'ignore', any given - labels that are in the dataset are dropped and no error is raised. + dimensions that are in the dataset are dropped and no error is raised. Returns ------- obj : Dataset The dataset without the given dimensions (or any variables - containing those dimensions) - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if - any of the dimensions passed are not - in the dataset. If 'ignore', any given dimensions that are in the - dataset are dropped and no error is raised. + containing those dimensions). """ if errors not in ["raise", "ignore"]: raise ValueError('errors must be either "raise" or "ignore"') @@ -4244,13 +4564,17 @@ def drop_dims( missing_dims = drop_dims - set(self.dims) if missing_dims: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dims + f"Dataset does not contain the dimensions: {missing_dims}" ) drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} return self.drop_vars(drop_vars) - def transpose(self, *dims: Hashable) -> "Dataset": + def transpose( + self, + *dims: Hashable, + missing_dims: str = "raise", + ) -> "Dataset": """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -4261,6 +4585,12 @@ def transpose(self, *dims: Hashable) -> "Dataset": *dims : hashable, optional By default, reverse the dimensions on each array. Otherwise, reorder the dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions Returns ------- @@ -4279,12 +4609,10 @@ def transpose(self, *dims: Hashable) -> "Dataset": numpy.transpose DataArray.transpose """ - if dims: - if set(dims) ^ set(self.dims) and ... not in dims: - raise ValueError( - "arguments to transpose (%s) must be " - "permuted dataset dimensions (%s)" % (dims, tuple(self.dims)) - ) + # Use infix_dims to check once for missing dimensions + if len(dims) != 0: + _ = list(infix_dims(dims, self.dims, missing_dims)) + ds = self.copy() for name, var in self._variables.items(): var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,))) @@ -4324,19 +4652,19 @@ def dropna( # depending on the order of the supplied axes. if dim not in self.dims: - raise ValueError("%s must be a single dataset dimension" % dim) + raise ValueError(f"{dim} must be a single dataset dimension") if subset is None: subset = iter(self.data_vars) count = np.zeros(self.dims[dim], dtype=np.int64) - size = 0 + size = np.int_(0) # for type checking for k in subset: array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) # type: ignore + count += np.asarray(array.count(dims)) # type: ignore[attr-defined] size += np.prod([self.dims[d] for d in dims]) if thresh is not None: @@ -4346,7 +4674,7 @@ def dropna( elif how == "all": mask = count > 0 elif how is not None: - raise ValueError("invalid how option: %s" % how) + raise ValueError(f"invalid how option: {how}") else: raise TypeError("must specify how or thresh") @@ -4375,9 +4703,6 @@ def fillna(self, value: Any) -> "Dataset": Examples -------- - - >>> import numpy as np - >>> import xarray as xr >>> ds = xr.Dataset( ... { ... "A": ("x", [np.nan, 2, np.nan, 0]), @@ -4452,7 +4777,6 @@ def interpolate_na( ---------- dim : str Specifies the dimension along which to interpolate. - method : str, optional String indicating which method to use for interpolation: @@ -4464,7 +4788,6 @@ def interpolate_na( provided. - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. - use_coordinate : bool, str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if @@ -4499,7 +4822,7 @@ def interpolate_na( * x (x) int64 0 1 2 3 4 5 6 7 8 The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively - kwargs : dict, optional + **kwargs : dict, optional parameters passed verbatim to the underlying interpolation function Returns @@ -4507,7 +4830,7 @@ def interpolate_na( interpolated: Dataset Filled in Dataset. - See also + See Also -------- numpy.interp scipy.interpolate @@ -4584,7 +4907,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset": The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -4609,7 +4933,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset": The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -4678,6 +5003,12 @@ def reduce( Dataset with this object's DataArrays replaced with new DataArrays of summarized data and the indicated dimension(s) removed. """ + if "axis" in kwargs: + raise ValueError( + "passing 'axis' to Dataset reduce methods is ambiguous." + " Please use 'dim' instead." + ) + if dim is None or dim is ...: dims = set(self.dims) elif isinstance(dim, str) or not isinstance(dim, Iterable): @@ -4688,7 +5019,7 @@ def reduce( missing_dimensions = [d for d in dims if d not in self.dims] if missing_dimensions: raise ValueError( - "Dataset does not contain the dimensions: %s" % missing_dimensions + f"Dataset does not contain the dimensions: {missing_dimensions}" ) if keep_attrs is None: @@ -4702,7 +5033,10 @@ def reduce( variables[name] = var else: if ( - not numeric_only + # Some reduction functions (e.g. std, var) need to run on variables + # that don't have the reduce dims: PR5393 + not reduce_dims + or not numeric_only or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ): @@ -4714,7 +5048,7 @@ def reduce( # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient - reduce_dims = None # type: ignore + reduce_dims = None # type: ignore[assignment] variables[name] = var.reduce( func, dim=reduce_dims, @@ -4724,7 +5058,7 @@ def reduce( ) coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} attrs = self.attrs if keep_attrs else None return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes @@ -4967,6 +5301,27 @@ def _normalize_dim_order( return ordered_dims + def to_pandas(self) -> Union[pd.Series, pd.DataFrame]: + """Convert this dataset into a pandas object without changing the number of dimensions. + + The type of the returned object depends on the number of Dataset + dimensions: + + * 0D -> `pandas.Series` + * 1D -> `pandas.DataFrame` + + Only works for Datasets with 1 or fewer dimensions. + """ + if len(self.dims) == 0: + return pd.Series({k: v.item() for k, v in self.items()}) + if len(self.dims) == 1: + return self.to_dataframe() + raise ValueError( + "cannot convert Datasets with %s dimensions into " + "pandas objects without changing the number of dimensions. " + "Please use Dataset.to_dataframe() instead." % len(self.dims) + ) + def _to_dataframe(self, ordered_dims: Mapping[Hashable, int]): columns = [k for k in self.variables if k not in self.dims] data = [ @@ -5014,7 +5369,7 @@ def _set_sparse_data_from_dataframe( if isinstance(idx, pd.MultiIndex): coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) - is_sorted = idx.is_lexsorted() + is_sorted = idx.is_monotonic_increasing shape = tuple(lev.size for lev in idx.levels) else: coords = np.arange(idx.size).reshape(1, -1) @@ -5101,7 +5456,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas ------- New Dataset. - See also + See Also -------- xarray.DataArray.from_series pandas.DataFrame.to_xarray @@ -5232,7 +5587,7 @@ def to_dict(self, data=True): Whether to include the actual data in the dictionary. When set to False, returns just the schema. - See also + See Also -------- Dataset.from_dict """ @@ -5320,78 +5675,62 @@ def from_dict(cls, d): return obj - @staticmethod - def _unary_op(f): - @functools.wraps(f) - def func(self, *args, **kwargs): - variables = {} - keep_attrs = kwargs.pop("keep_attrs", None) - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - for k, v in self._variables.items(): - if k in self._coord_names: - variables[k] = v - else: - variables[k] = f(v, *args, **kwargs) - if keep_attrs: - variables[k].attrs = v._attrs - attrs = self._attrs if keep_attrs else None - return self._replace_with_new_dims(variables, attrs=attrs) - - return func - - @staticmethod - def _binary_op(f, reflexive=False, join=None): - @functools.wraps(f) - def func(self, other): - from .dataarray import DataArray - - if isinstance(other, groupby.GroupBy): - return NotImplemented - align_type = OPTIONS["arithmetic_join"] if join is None else join - if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) - g = f if not reflexive else lambda x, y: f(y, x) - ds = self._calculate_binary_op(g, other, join=align_type) - return ds - - return func - - @staticmethod - def _inplace_binary_op(f): - @functools.wraps(f) - def func(self, other): - from .dataarray import DataArray - - if isinstance(other, groupby.GroupBy): - raise TypeError( - "in-place operations between a Dataset and " - "a grouped object are not permitted" - ) - # we don't actually modify arrays in-place with in-place Dataset - # arithmetic -- this lets us automatically align things - if isinstance(other, (DataArray, Dataset)): - other = other.reindex_like(self, copy=False) - g = ops.inplace_to_noninplace_op(f) - ds = self._calculate_binary_op(g, other, inplace=True) - self._replace_with_new_dims( - ds._variables, - ds._coord_names, - attrs=ds._attrs, - indexes=ds._indexes, - inplace=True, - ) - return self + def _unary_op(self, f, *args, **kwargs): + variables = {} + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + for k, v in self._variables.items(): + if k in self._coord_names: + variables[k] = v + else: + variables[k] = f(v, *args, **kwargs) + if keep_attrs: + variables[k].attrs = v._attrs + attrs = self._attrs if keep_attrs else None + return self._replace_with_new_dims(variables, attrs=attrs) + + def _binary_op(self, other, f, reflexive=False, join=None): + from .dataarray import DataArray - return func + if isinstance(other, groupby.GroupBy): + return NotImplemented + align_type = OPTIONS["arithmetic_join"] if join is None else join + if isinstance(other, (DataArray, Dataset)): + self, other = align(self, other, join=align_type, copy=False) + g = f if not reflexive else lambda x, y: f(y, x) + ds = self._calculate_binary_op(g, other, join=align_type) + return ds + + def _inplace_binary_op(self, other, f): + from .dataarray import DataArray + + if isinstance(other, groupby.GroupBy): + raise TypeError( + "in-place operations between a Dataset and " + "a grouped object are not permitted" + ) + # we don't actually modify arrays in-place with in-place Dataset + # arithmetic -- this lets us automatically align things + if isinstance(other, (DataArray, Dataset)): + other = other.reindex_like(self, copy=False) + g = ops.inplace_to_noninplace_op(f) + ds = self._calculate_binary_op(g, other, inplace=True) + self._replace_with_new_dims( + ds._variables, + ds._coord_names, + attrs=ds._attrs, + indexes=ds._indexes, + inplace=True, + ) + return self def _calculate_binary_op(self, f, other, join="inner", inplace=False): def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars): if inplace and set(lhs_data_vars) != set(rhs_data_vars): raise ValueError( "datasets must have the same data variables " - "for in-place arithmetic operations: %s, %s" - % (list(lhs_data_vars), list(rhs_data_vars)) + f"for in-place arithmetic operations: {list(lhs_data_vars)}, {list(rhs_data_vars)}" ) dest_vars = {} @@ -5454,10 +5793,10 @@ def diff(self, dim, n=1, label="upper"): difference : same type as caller The n-th order finite difference of this object. - .. note:: - - `n` matches numpy's behavior and is different from pandas' first - argument named `periods`. + Notes + ----- + `n` matches numpy's behavior and is different from pandas' first argument named + `periods`. Examples -------- @@ -5507,9 +5846,15 @@ def diff(self, dim, n=1, label="upper"): else: variables[name] = var - indexes = dict(self.indexes) + indexes = dict(self.xindexes) if dim in indexes: - indexes[dim] = indexes[dim][kwargs_new[dim]] + if isinstance(indexes[dim], PandasIndex): + # maybe optimize? (pandas index already indexed above with var.isel) + new_index = indexes[dim].index[kwargs_new[dim]] + if isinstance(new_index, pd.MultiIndex): + indexes[dim] = PandasMultiIndex(new_index, dim) + else: + indexes[dim] = PandasIndex(new_index, dim) difference = self._replace_with_new_dims(variables, indexes=indexes) @@ -5543,13 +5888,12 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Dataset with the same coordinates and attributes but shifted data variables. - See also + See Also -------- roll Examples -------- - >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}) >>> ds.shift(x=2) @@ -5561,7 +5905,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift") invalid = [k for k in shifts if k not in self.dims] if invalid: - raise ValueError("dimensions %r do not exist" % invalid) + raise ValueError(f"dimensions {invalid!r} do not exist") variables = {} for name, var in self.variables.items(): @@ -5588,7 +5932,6 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): Parameters ---------- - shifts : dict, optional A dict with keys matching dimensions and values given by integers to rotate each of the given dimensions. Positive @@ -5607,13 +5950,12 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): Dataset with the same coordinates and attributes but rolled variables. - See also + See Also -------- shift Examples -------- - >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}) >>> ds.roll(x=2) @@ -5625,7 +5967,7 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll") invalid = [k for k in shifts if k not in self.dims] if invalid: - raise ValueError("dimensions %r do not exist" % invalid) + raise ValueError(f"dimensions {invalid!r} do not exist") if roll_coords is None: warnings.warn( @@ -5649,14 +5991,14 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): if roll_coords: indexes = {} - for k, v in self.indexes.items(): + for k, v in self.xindexes.items(): (dim,) = self.variables[k].dims if dim in shifts: indexes[k] = roll_index(v, shifts[dim]) else: indexes[k] = v else: - indexes = dict(self.indexes) + indexes = dict(self.xindexes) return self._replace(variables, indexes=indexes) @@ -5680,10 +6022,10 @@ def sortby(self, variables, ascending=True): Parameters ---------- - variables: str, DataArray, or list of str or DataArray + variables : str, DataArray, or list of str or DataArray 1D DataArray objects or name(s) of 1D variable(s) in coords/data_vars whose values are used to sort the dataset. - ascending: bool, optional + ascending : bool, optional Whether to sort by ascending or descending order. Returns @@ -5771,7 +6113,6 @@ def quantile( Examples -------- - >>> ds = xr.Dataset( ... {"a": (("x", "y"), [[0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]])}, ... coords={"x": [7, 9], "y": [1, 1.5, 2, 2.5]}, @@ -5850,7 +6191,7 @@ def quantile( # construct the new dataset coord_names = {k for k in self.coords if k in variables} - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) attrs = self.attrs if keep_attrs else None @@ -5887,8 +6228,14 @@ def rank(self, dim, pct=False, keep_attrs=None): ranked : Dataset Variables that do not depend on `dim` are dropped. """ + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "rank requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) + if dim not in self.dims: - raise ValueError("Dataset does not contain the dimension: %s" % dim) + raise ValueError(f"Dataset does not contain the dimension: {dim}") variables = {} for name, var in self.variables.items(): @@ -5956,7 +6303,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): if _contains_datetime_like_objects(v): v = v._to_numeric(datetime_unit=datetime_unit) grad = duck_array_ops.gradient( - v.data, coord_var, edge_order=edge_order, axis=v.get_axis_num(dim) + v.data, + coord_var.data, + edge_order=edge_order, + axis=v.get_axis_num(dim), ) variables[k] = Variable(v.dims, grad) else: @@ -5964,7 +6314,9 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None): return self._replace(variables) def integrate( - self, coord: Union[Hashable, Sequence[Hashable]], datetime_unit: str = None + self, + coord: Union[Hashable, Sequence[Hashable]], + datetime_unit: str = None, ) -> "Dataset": """Integrate along the given coordinate using the trapezoidal rule. @@ -5974,9 +6326,9 @@ def integrate( Parameters ---------- - coord: hashable, or a sequence of hashable + coord : hashable, or sequence of hashable Coordinate(s) used for the integration. - datetime_unit: {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ 'ps', 'fs', 'as'}, optional Specify the unit if datetime coordinate is used. @@ -5987,7 +6339,7 @@ def integrate( See also -------- DataArray.integrate - numpy.trapz: corresponding numpy function + numpy.trapz : corresponding numpy function Examples -------- @@ -6024,7 +6376,7 @@ def integrate( result = result._integrate_one(c, datetime_unit=datetime_unit) return result - def _integrate_one(self, coord, datetime_unit=None): + def _integrate_one(self, coord, datetime_unit=None, cumulative=False): from .variable import Variable if coord not in self.variables and coord not in self.dims: @@ -6051,26 +6403,107 @@ def _integrate_one(self, coord, datetime_unit=None): coord_names = set() for k, v in self.variables.items(): if k in self.coords: - if dim not in v.dims: + if dim not in v.dims or cumulative: variables[k] = v coord_names.add(k) else: if k in self.data_vars and dim in v.dims: if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) - integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) - ) - v_dims = list(v.dims) - v_dims.remove(dim) + if cumulative: + integ = duck_array_ops.cumulative_trapezoid( + v.data, coord_var.data, axis=v.get_axis_num(dim) + ) + v_dims = v.dims + else: + integ = duck_array_ops.trapz( + v.data, coord_var.data, axis=v.get_axis_num(dim) + ) + v_dims = list(v.dims) + v_dims.remove(dim) variables[k] = Variable(v_dims, integ) else: variables[k] = v - indexes = {k: v for k, v in self.indexes.items() if k in variables} + indexes = {k: v for k, v in self.xindexes.items() if k in variables} return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes ) + def cumulative_integrate( + self, + coord: Union[Hashable, Sequence[Hashable]], + datetime_unit: str = None, + ) -> "Dataset": + """Integrate along the given coordinate using the trapezoidal rule. + + .. note:: + This feature is limited to simple cartesian geometry, i.e. coord + must be one dimensional. + + The first entry of the cumulative integral of each variable is always 0, in + order to keep the length of the dimension unchanged between input and + output. + + Parameters + ---------- + coord : hashable, or sequence of hashable + Coordinate(s) used for the integration. + datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ + 'ps', 'fs', 'as'}, optional + Specify the unit if datetime coordinate is used. + + Returns + ------- + integrated : Dataset + + See also + -------- + DataArray.cumulative_integrate + scipy.integrate.cumulative_trapezoid : corresponding scipy function + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={"a": ("x", [5, 5, 6, 6]), "b": ("x", [1, 2, 1, 0])}, + ... coords={"x": [0, 1, 2, 3], "y": ("x", [1, 7, 3, 5])}, + ... ) + >>> ds + + Dimensions: (x: 4) + Coordinates: + * x (x) int64 0 1 2 3 + y (x) int64 1 7 3 5 + Data variables: + a (x) int64 5 5 6 6 + b (x) int64 1 2 1 0 + >>> ds.cumulative_integrate("x") + + Dimensions: (x: 4) + Coordinates: + * x (x) int64 0 1 2 3 + y (x) int64 1 7 3 5 + Data variables: + a (x) float64 0.0 5.0 10.5 16.5 + b (x) float64 0.0 1.5 3.0 3.5 + >>> ds.cumulative_integrate("y") + + Dimensions: (x: 4) + Coordinates: + * x (x) int64 0 1 2 3 + y (x) int64 1 7 3 5 + Data variables: + a (x) float64 0.0 30.0 8.0 20.0 + b (x) float64 0.0 9.0 3.0 4.0 + """ + if not isinstance(coord, (list, tuple)): + coord = (coord,) + result = self + for c in coord: + result = result._integrate_one( + c, datetime_unit=datetime_unit, cumulative=True + ) + return result + @property def real(self): return self.map(lambda x: x.real, keep_attrs=True) @@ -6110,7 +6543,6 @@ def filter_by_attrs(self, **kwargs): Examples -------- - >>> # Create an example dataset: >>> temp = 15 + 8 * np.random.randn(2, 2, 3) >>> precip = 10 * np.random.rand(2, 2, 3) >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] @@ -6118,22 +6550,25 @@ def filter_by_attrs(self, **kwargs): >>> dims = ["x", "y", "time"] >>> temp_attr = dict(standard_name="air_potential_temperature") >>> precip_attr = dict(standard_name="convective_precipitation_flux") + >>> ds = xr.Dataset( - ... { - ... "temperature": (dims, temp, temp_attr), - ... "precipitation": (dims, precip, precip_attr), - ... }, - ... coords={ - ... "lon": (["x", "y"], lon), - ... "lat": (["x", "y"], lat), - ... "time": pd.date_range("2014-09-06", periods=3), - ... "reference_time": pd.Timestamp("2014-09-05"), - ... }, + ... dict( + ... temperature=(dims, temp, temp_attr), + ... precipitation=(dims, precip, precip_attr), + ... ), + ... coords=dict( + ... lon=(["x", "y"], lon), + ... lat=(["x", "y"], lat), + ... time=pd.date_range("2014-09-06", periods=3), + ... reference_time=pd.Timestamp("2014-09-05"), + ... ), ... ) - >>> # Get variables matching a specific standard_name. + + Get variables matching a specific standard_name: + >>> ds.filter_by_attrs(standard_name="convective_precipitation_flux") - Dimensions: (time: 3, x: 2, y: 2) + Dimensions: (x: 2, y: 2, time: 3) Coordinates: lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 lat (x, y) float64 42.25 42.21 42.63 42.59 @@ -6142,11 +6577,13 @@ def filter_by_attrs(self, **kwargs): Dimensions without coordinates: x, y Data variables: precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 - >>> # Get all variables that have a standard_name attribute. + + Get all variables that have a standard_name attribute: + >>> standard_name = lambda v: v is not None >>> ds.filter_by_attrs(standard_name=standard_name) - Dimensions: (time: 3, x: 2, y: 2) + Dimensions: (x: 2, y: 2, time: 3) Coordinates: lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 lat (x, y) float64 42.25 42.21 42.63 42.59 @@ -6177,46 +6614,14 @@ def unify_chunks(self) -> "Dataset": Returns ------- - Dataset with consistent chunk sizes for all dask-array variables See Also -------- - dask.array.core.unify_chunks """ - try: - self.chunks - except ValueError: # "inconsistent chunks" - pass - else: - # No variables with dask backend, or all chunks are already aligned - return self.copy() - - # import dask is placed after the quick exit test above to allow - # running this method if dask isn't installed and there are no chunks - import dask.array - - ds = self.copy() - - dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)} - - dask_array_names = [] - dask_unify_args = [] - for name, variable in ds.variables.items(): - if isinstance(variable.data, dask.array.Array): - dims_tuple = [dims_pos_map[dim] for dim in variable.dims] - dask_array_names.append(name) - dask_unify_args.append(variable.data) - dask_unify_args.append(dims_tuple) - - _, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args) - - for name, new_array in zip(dask_array_names, rechunked_arrays): - ds.variables[name]._data = new_array - - return ds + return unify_chunks(self)[0] def map_blocks( self, @@ -6257,7 +6662,6 @@ def map_blocks( When provided, ``attrs`` on variables in `template` are copied over to the result. Any ``attrs`` set by ``func`` will be ignored. - Returns ------- A single DataArray or Dataset with dask backend, reassembled from the outputs of the @@ -6266,20 +6670,19 @@ def map_blocks( Notes ----- This function is designed for when ``func`` needs to manipulate a whole xarray object - subset to each block. In the more common case where ``func`` can work on numpy arrays, it is - recommended to use ``apply_ufunc``. + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. If none of the variables in this object is backed by dask arrays, calling this function is equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks xarray.DataArray.map_blocks Examples -------- - Calculate an anomaly from climatology using ``.groupby()``. Using ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, its indices, and its methods like ``.groupby()``. @@ -6365,7 +6768,6 @@ def polyfit( Whether to return to the covariance matrix in addition to the coefficients. The matrix is not scaled if `cov='unscaled'`. - Returns ------- polyfit_results : Dataset @@ -6390,9 +6792,11 @@ def polyfit( The rank of the coefficient matrix in the least-squares fit is deficient. The warning is not raised with in-memory (not dask) data and `full=True`. - See also + See Also -------- numpy.polyfit + numpy.polyval + xarray.polyval """ variables = {} skipna_da = skipna @@ -6403,7 +6807,9 @@ def polyfit( lhs = np.vander(x, order) if rcond is None: - rcond = x.shape[0] * np.core.finfo(x.dtype).eps + rcond = ( + x.shape[0] * np.core.finfo(x.dtype).eps # type: ignore[attr-defined] + ) # Weights: if w is not None: @@ -6447,7 +6853,7 @@ def polyfit( # deficient ranks nor does it output the "full" info (issue dask/dask#6516) skipna_da = True elif skipna is None: - skipna_da = np.any(da.isnull()) + skipna_da = bool(np.any(da.isnull())) dims_to_stack = [dimname for dimname in da.dims if dimname != dim] stacked_coords: Dict[Hashable, DataArray] = {} @@ -6553,36 +6959,26 @@ def pad( mode : str, default: "constant" One of the following string values (taken from numpy docs). - 'constant' (default) - Pads with a constant value. - 'edge' - Pads with the edge values of array. - 'linear_ramp' - Pads with the linear ramp between end_value and the - array edge value. - 'maximum' - Pads with the maximum value of all or part of the - vector along each axis. - 'mean' - Pads with the mean value of all or part of the - vector along each axis. - 'median' - Pads with the median value of all or part of the - vector along each axis. - 'minimum' - Pads with the minimum value of all or part of the - vector along each axis. - 'reflect' - Pads with the reflection of the vector mirrored on - the first and last values of the vector along each - axis. - 'symmetric' - Pads with the reflection of the vector mirrored - along the edge of the array. - 'wrap' - Pads with the wrap of the vector along the axis. - The first values are used to pad the end and the - end values are used to pad the beginning. + - constant: Pads with a constant value. + - edge: Pads with the edge values of array. + - linear_ramp: Pads with the linear ramp between end_value and the + array edge value. + - maximum: Pads with the maximum value of all or part of the + vector along each axis. + - mean: Pads with the mean value of all or part of the + vector along each axis. + - median: Pads with the median value of all or part of the + vector along each axis. + - minimum: Pads with the minimum value of all or part of the + vector along each axis. + - reflect: Pads with the reflection of the vector mirrored on + the first and last values of the vector along each axis. + - symmetric: Pads with the reflection of the vector mirrored + along the edge of the array. + - wrap: Pads with the wrap of the vector along the axis. + The first values are used to pad the end and the + end values are used to pad the beginning. + stat_length : int, tuple or mapping of hashable to tuple, default: None Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. @@ -6627,7 +7023,7 @@ def pad( padded : Dataset Dataset with the padded coordinates and data. - See also + See Also -------- Dataset.shift, Dataset.roll, Dataset.bfill, Dataset.ffill, numpy.pad, dask.array.pad @@ -6639,7 +7035,6 @@ def pad( Examples -------- - >>> ds = xr.Dataset({"foo": ("x", range(5))}) >>> ds.pad(x=(1, 2)) @@ -6680,7 +7075,7 @@ def pad( variables[name] = var.pad( pad_width=var_pad_width, mode=coord_pad_mode, - **coord_pad_options, # type: ignore + **coord_pad_options, # type: ignore[arg-type] ) return self._replace_vars_and_dims(variables) @@ -6728,13 +7123,12 @@ def idxmin( New `Dataset` object with `idxmin` applied to its data and the indicated dimension removed. - See also + See Also -------- DataArray.idxmin, Dataset.idxmax, Dataset.min, Dataset.argmin Examples -------- - >>> array1 = xr.DataArray( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) @@ -6826,13 +7220,12 @@ def idxmax( New `Dataset` object with `idxmax` applied to its data and the indicated dimension removed. - See also + See Also -------- DataArray.idxmax, Dataset.idxmin, Dataset.max, Dataset.argmax Examples -------- - >>> array1 = xr.DataArray( ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} ... ) @@ -6881,7 +7274,7 @@ def idxmax( ) ) - def argmin(self, dim=None, axis=None, **kwargs): + def argmin(self, dim=None, **kwargs): """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -6895,9 +7288,6 @@ def argmin(self, dim=None, axis=None, **kwargs): this is deprecated, in future will be an error, since DataArray.argmin will return a dict with indices for all dimensions, which does not make sense for a Dataset. - axis : int, optional - Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments - can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -6912,31 +7302,28 @@ def argmin(self, dim=None, axis=None, **kwargs): ------- result : Dataset - See also + See Also -------- DataArray.argmin - """ - if dim is None and axis is None: + if dim is None: warnings.warn( - "Once the behaviour of DataArray.argmin() and Variable.argmin() with " - "neither dim nor axis argument changes to return a dict of indices of " - "each dimension, for consistency it will be an error to call " - "Dataset.argmin() with no argument, since we don't return a dict of " - "Datasets.", + "Once the behaviour of DataArray.argmin() and Variable.argmin() without " + "dim changes to return a dict of indices of each dimension, for " + "consistency it will be an error to call Dataset.argmin() with no argument," + "since we don't return a dict of Datasets.", DeprecationWarning, stacklevel=2, ) if ( dim is None - or axis is not None or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a # sequence argmin_func = getattr(duck_array_ops, "argmin") - return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) + return self.reduce(argmin_func, dim=dim, **kwargs) else: raise ValueError( "When dim is a sequence or ..., DataArray.argmin() returns a dict. " @@ -6944,7 +7331,7 @@ def argmin(self, dim=None, axis=None, **kwargs): "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self, dim=None, axis=None, **kwargs): + def argmax(self, dim=None, **kwargs): """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -6958,9 +7345,6 @@ def argmax(self, dim=None, axis=None, **kwargs): this is deprecated, in future will be an error, since DataArray.argmax will return a dict with indices for all dimensions, which does not make sense for a Dataset. - axis : int, optional - Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments - can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -6975,31 +7359,29 @@ def argmax(self, dim=None, axis=None, **kwargs): ------- result : Dataset - See also + See Also -------- DataArray.argmax """ - if dim is None and axis is None: + if dim is None: warnings.warn( - "Once the behaviour of DataArray.argmax() and Variable.argmax() with " - "neither dim nor axis argument changes to return a dict of indices of " - "each dimension, for consistency it will be an error to call " - "Dataset.argmax() with no argument, since we don't return a dict of " - "Datasets.", + "Once the behaviour of DataArray.argmin() and Variable.argmin() without " + "dim changes to return a dict of indices of each dimension, for " + "consistency it will be an error to call Dataset.argmin() with no argument," + "since we don't return a dict of Datasets.", DeprecationWarning, stacklevel=2, ) if ( dim is None - or axis is not None or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a # sequence argmax_func = getattr(duck_array_ops, "argmax") - return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) + return self.reduce(argmax_func, dim=dim, **kwargs) else: raise ValueError( "When dim is a sequence or ..., DataArray.argmin() returns a dict. " @@ -7007,5 +7389,273 @@ def argmax(self, dim=None, axis=None, **kwargs): "Dataset.argmin() with a sequence or ... for dim" ) + def query( + self, + queries: Mapping[Hashable, Any] = None, + parser: str = "pandas", + engine: str = None, + missing_dims: str = "raise", + **queries_kwargs: Any, + ) -> "Dataset": + """Return a new dataset with each array indexed along the specified + dimension(s), where the indexers are given as strings containing + Python expressions to be evaluated against the data variables in the + dataset. + + Parameters + ---------- + queries : dict, optional + A dict with keys matching dimensions and values given by strings + containing Python expressions to be evaluated against the data variables + in the dataset. The expressions will be evaluated using the pandas + eval() function, and can contain any valid Python expressions but cannot + contain any Python statements. + parser : {"pandas", "python"}, default: "pandas" + The parser to use to construct the syntax tree from the expression. + The default of 'pandas' parses code slightly different than standard + Python. Alternatively, you can parse an expression using the 'python' + parser to retain strict Python semantics. + engine : {"python", "numexpr", None}, default: None + The engine used to evaluate the expression. Supported engines are: + + - None: tries to use numexpr, falls back to python + - "numexpr": evaluates expressions using numexpr + - "python": performs operations as if you had eval’d in top level python + + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + + - "raise": raise an exception + - "warning": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + **queries_kwargs : {dim: query, ...}, optional + The keyword arguments form of ``queries``. + One of queries or queries_kwargs must be provided. + + Returns + ------- + obj : Dataset + A new Dataset with the same contents as this dataset, except each + array and dimension is indexed by the results of the appropriate + queries. + + See Also + -------- + Dataset.isel + pandas.eval + + Examples + -------- + >>> a = np.arange(0, 5, 1) + >>> b = np.linspace(0, 1, 5) + >>> ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + >>> ds + + Dimensions: (x: 5) + Dimensions without coordinates: x + Data variables: + a (x) int64 0 1 2 3 4 + b (x) float64 0.0 0.25 0.5 0.75 1.0 + >>> ds.query(x="a > 2") + + Dimensions: (x: 2) + Dimensions without coordinates: x + Data variables: + a (x) int64 3 4 + b (x) float64 0.75 1.0 + """ + + # allow queries to be given either as a dict or as kwargs + queries = either_dict_or_kwargs(queries, queries_kwargs, "query") + + # check queries + for dim, expr in queries.items(): + if not isinstance(expr, str): + msg = f"expr for dim {dim} must be a string to be evaluated, {type(expr)} given" + raise ValueError(msg) + + # evaluate the queries to create the indexers + indexers = { + dim: pd.eval(expr, resolvers=[self], parser=parser, engine=engine) + for dim, expr in queries.items() + } + + # apply the selection + return self.isel(indexers, missing_dims=missing_dims) + + def curvefit( + self, + coords: Union[Union[str, "DataArray"], Iterable[Union[str, "DataArray"]]], + func: Callable[..., Any], + reduce_dims: Union[Hashable, Iterable[Hashable]] = None, + skipna: bool = True, + p0: Dict[str, Any] = None, + bounds: Dict[str, Any] = None, + param_names: Sequence[str] = None, + kwargs: Dict[str, Any] = None, + ): + """ + Curve fitting optimization for arbitrary functions. + + Wraps `scipy.optimize.curve_fit` with `apply_ufunc`. + + Parameters + ---------- + coords : hashable, DataArray, or sequence of hashable or DataArray + Independent coordinate(s) over which to perform the curve fitting. Must share + at least one dimension with the calling object. When fitting multi-dimensional + functions, supply `coords` as a sequence in the same order as arguments in + `func`. To fit along existing dimensions of the calling object, `coords` can + also be specified as a str or sequence of strs. + func : callable + User specified function in the form `f(x, *params)` which returns a numpy + array of length `len(x)`. `params` are the fittable parameters which are optimized + by scipy curve_fit. `x` can also be specified as a sequence containing multiple + coordinates, e.g. `f((x0, x1), *params)`. + reduce_dims : hashable or sequence of hashable + Additional dimension(s) over which to aggregate while fitting. For example, + calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will + aggregate all lat and lon points and fit the specified function along the + time dimension. + skipna : bool, optional + Whether to skip missing values when fitting. Default is True. + p0 : dict-like, optional + Optional dictionary of parameter names to initial guesses passed to the + `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will + be assigned initial values following the default scipy behavior. + bounds : dict-like, optional + Optional dictionary of parameter names to bounding values passed to the + `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest + will be unbounded following the default scipy behavior. + param_names : sequence of hashable, optional + Sequence of names for the fittable parameters of `func`. If not supplied, + this will be automatically determined by arguments of `func`. `param_names` + should be manually supplied when fitting a function that takes a variable + number of parameters. + **kwargs : optional + Additional keyword arguments to passed to scipy curve_fit. + + Returns + ------- + curvefit_results : Dataset + A single dataset which contains: + + [var]_curvefit_coefficients + The coefficients of the best fit. + [var]_curvefit_covariance + The covariance matrix of the coefficient estimates. + + See Also + -------- + Dataset.polyfit + scipy.optimize.curve_fit + """ + from scipy.optimize import curve_fit + + if p0 is None: + p0 = {} + if bounds is None: + bounds = {} + if kwargs is None: + kwargs = {} + + if not reduce_dims: + reduce_dims_ = [] + elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable): + reduce_dims_ = [reduce_dims] + else: + reduce_dims_ = list(reduce_dims) + + if ( + isinstance(coords, str) + or isinstance(coords, xr.DataArray) + or not isinstance(coords, Iterable) + ): + coords = [coords] + coords_ = [self[coord] if isinstance(coord, str) else coord for coord in coords] + + # Determine whether any coords are dims on self + for coord in coords_: + reduce_dims_ += [c for c in self.dims if coord.equals(self[c])] + reduce_dims_ = list(set(reduce_dims_)) + preserved_dims = list(set(self.dims) - set(reduce_dims_)) + if not reduce_dims_: + raise ValueError( + "No arguments to `coords` were identified as a dimension on the calling " + "object, and no dims were supplied to `reduce_dims`. This would result " + "in fitting on scalar data." + ) + + # Broadcast all coords with each other + coords_ = xr.broadcast(*coords_) + coords_ = [ + coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_ + ] -ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) + params, func_args = _get_func_args(func, param_names) + param_defaults, bounds_defaults = _initialize_curvefit_params( + params, p0, bounds, func_args + ) + n_params = len(params) + kwargs.setdefault("p0", [param_defaults[p] for p in params]) + kwargs.setdefault( + "bounds", + [ + [bounds_defaults[p][0] for p in params], + [bounds_defaults[p][1] for p in params], + ], + ) + + def _wrapper(Y, *coords_, **kwargs): + # Wrap curve_fit with raveled coordinates and pointwise NaN handling + x = np.vstack([c.ravel() for c in coords_]) + y = Y.ravel() + if skipna: + mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0) + x = x[:, mask] + y = y[mask] + if not len(y): + popt = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + return popt, pcov + x = np.squeeze(x) + popt, pcov = curve_fit(func, x, y, **kwargs) + return popt, pcov + + result = xr.Dataset() + for name, da in self.data_vars.items(): + if name is xr.core.dataarray._THIS_ARRAY: + name = "" + else: + name = f"{str(name)}_" + + popt, pcov = xr.apply_ufunc( + _wrapper, + da, + *coords_, + vectorize=True, + dask="parallelized", + input_core_dims=[reduce_dims_ for d in range(len(coords_) + 1)], + output_core_dims=[["param"], ["cov_i", "cov_j"]], + dask_gufunc_kwargs={ + "output_sizes": { + "param": n_params, + "cov_i": n_params, + "cov_j": n_params, + }, + }, + output_dtypes=(np.float64, np.float64), + exclude_dims=set(reduce_dims_), + kwargs=kwargs, + ) + result[name + "curvefit_coefficients"] = popt + result[name + "curvefit_covariance"] = pcov + + result = result.assign_coords( + {"param": params, "cov_i": params, "cov_j": params} + ) + result.attrs = self.attrs.copy() + + return result diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 167f00fa932..5f9349051b7 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -63,10 +63,7 @@ def maybe_promote(dtype): # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") elif np.issubdtype(dtype, np.integer): - if dtype.itemsize <= 2: - dtype = np.float32 - else: - dtype = np.float64 + dtype = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): fill_value = np.nan + np.nan * 1j @@ -78,7 +75,7 @@ def maybe_promote(dtype): return np.dtype(dtype), fill_value -NAT_TYPES = (np.datetime64("NaT"), np.timedelta64("NaT")) +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} def get_fill_value(dtype): @@ -96,40 +93,56 @@ def get_fill_value(dtype): return fill_value -def get_pos_infinity(dtype): +def get_pos_infinity(dtype, max_for_int=False): """Return an appropriate positive infinity for this dtype. Parameters ---------- dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf Returns ------- fill_value : positive infinity value corresponding to this dtype. """ - if issubclass(dtype.type, (np.floating, np.integer)): + if issubclass(dtype.type, np.floating): return np.inf + if issubclass(dtype.type, np.integer): + if max_for_int: + return np.iinfo(dtype).max + else: + return np.inf + if issubclass(dtype.type, np.complexfloating): return np.inf + 1j * np.inf return INF -def get_neg_infinity(dtype): +def get_neg_infinity(dtype, min_for_int=False): """Return an appropriate positive infinity for this dtype. Parameters ---------- dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf Returns ------- fill_value : positive infinity value corresponding to this dtype. """ - if issubclass(dtype.type, (np.floating, np.integer)): + if issubclass(dtype.type, np.floating): return -np.inf + if issubclass(dtype.type, np.integer): + if min_for_int: + return np.iinfo(dtype).min + else: + return -np.inf + if issubclass(dtype.type, np.complexfloating): return -np.inf - 1j * np.inf diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e6c3aae5bf8..579ac3a7b0f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -7,7 +7,6 @@ import datetime import inspect import warnings -from distutils.version import LooseVersion from functools import partial import numpy as np @@ -20,6 +19,7 @@ dask_array_type, is_duck_dask_array, sparse_array_type, + sparse_version, ) from .utils import is_duck_array @@ -27,7 +27,7 @@ import dask.array as dask_array from dask.base import tokenize except ImportError: - dask_array = None # type: ignore + dask_array = None def _dask_or_eager_func( @@ -72,10 +72,6 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): raise NotImplementedError(msg % func_name) -# switch to use dask.array / __array_function__ version when dask supports it: -# https://github.com/dask/dask/pull/4822 -moveaxis = npcompat.moveaxis - around = _dask_or_eager_func("around") isclose = _dask_or_eager_func("isclose") @@ -153,21 +149,32 @@ def trapz(y, x, axis): return sum(integrand, axis=axis, skipna=False) +def cumulative_trapezoid(y, x, axis): + if axis < 0: + axis = y.ndim + axis + x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1) + x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1) + slice1 = (slice(None),) * axis + (slice(1, None),) + slice2 = (slice(None),) * axis + (slice(None, -1),) + dx = x[x_sl1] - x[x_sl2] + integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)]) + + # Pad so that 'axis' has same length in result as it did in y + pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] + integrand = pad(integrand, pads, mode="constant", constant_values=0.0) + + return cumsum(integrand, axis=axis, skipna=False) + + masked_invalid = _dask_or_eager_func( "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None) ) def astype(data, dtype, **kwargs): - try: - import sparse - except ImportError: - sparse = None - if ( - sparse is not None - and isinstance(data, sparse_array_type) - and LooseVersion(sparse.__version__) < LooseVersion("0.11.0") + isinstance(data, sparse_array_type) + and sparse_version < "0.11.0" and "casting" in kwargs ): warnings.warn( @@ -187,7 +194,7 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any([isinstance(x, cupy_array_type) for x in scalars_or_arrays]): + if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] @@ -310,13 +317,21 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False): +def _create_nan_agg_method( + name, dask_module=dask_array, coerce_strings=False, invariant_0d=False +): from . import nanops def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop("out", None) is not None: raise TypeError(f"`out` is not valid for {name}") + # The data is invariant in the case of 0d data, so do not + # change the data (and dtype) + # See https://github.com/pydata/xarray/issues/4885 + if invariant_0d and axis == (): + return values + values = asarray(values) if coerce_strings and values.dtype.kind in "SU": @@ -354,28 +369,30 @@ def f(values, axis=None, skipna=None, **kwargs): # See ops.inject_reduce_methods argmax = _create_nan_agg_method("argmax", coerce_strings=True) argmin = _create_nan_agg_method("argmin", coerce_strings=True) -max = _create_nan_agg_method("max", coerce_strings=True) -min = _create_nan_agg_method("min", coerce_strings=True) -sum = _create_nan_agg_method("sum") +max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True) +min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True) +sum = _create_nan_agg_method("sum", invariant_0d=True) sum.numeric_only = True sum.available_min_count = True std = _create_nan_agg_method("std") std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method("median", dask_module=dask_array_compat) +median = _create_nan_agg_method( + "median", dask_module=dask_array_compat, invariant_0d=True +) median.numeric_only = True -prod = _create_nan_agg_method("prod") +prod = _create_nan_agg_method("prod", invariant_0d=True) prod.numeric_only = True prod.available_min_count = True -cumprod_1d = _create_nan_agg_method("cumprod") +cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True) cumprod_1d.numeric_only = True -cumsum_1d = _create_nan_agg_method("cumsum") +cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) cumsum_1d.numeric_only = True unravel_index = _dask_or_eager_func("unravel_index") -_mean = _create_nan_agg_method("mean") +_mean = _create_nan_agg_method("mean", invariant_0d=True) def _datetime_nanmin(array): @@ -400,27 +417,23 @@ def _datetime_nanmin(array): def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): """Convert an array containing datetime-like data to numerical values. - Convert the datetime array to a timedelta relative to an offset. - Parameters ---------- - da : array-like - Input data - offset: None, datetime or cftime.datetime - Datetime offset. If None, this is set by default to the array's minimum - value to reduce round off errors. - datetime_unit: {None, Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} - If not None, convert output to a given datetime unit. Note that some - conversions are not allowed due to non-linear relationships between units. - dtype: dtype - Output dtype. - + array : array-like + Input data + offset : None, datetime or cftime.datetime + Datetime offset. If None, this is set by default to the array's minimum + value to reduce round off errors. + datetime_unit : {None, Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} + If not None, convert output to a given datetime unit. Note that some + conversions are not allowed due to non-linear relationships between units. + dtype : dtype + Output dtype. Returns ------- array - Numerical representation of datetime object relative to an offset. - + Numerical representation of datetime object relative to an offset. Notes ----- Some datetime unit conversions won't work, for example from days to years, even @@ -463,12 +476,12 @@ def timedelta_to_numeric(value, datetime_unit="ns", dtype=float): Parameters ---------- value : datetime.timedelta, numpy.timedelta64, pandas.Timedelta, str - Time delta representation. + Time delta representation. datetime_unit : {Y, M, W, D, h, m, s, ms, us, ns, ps, fs, as} - The time units of the output values. Note that some conversions are not allowed due to - non-linear relationships between units. + The time units of the output values. Note that some conversions are not allowed due to + non-linear relationships between units. dtype : type - The output data type. + The output data type. """ import datetime as dt @@ -564,7 +577,7 @@ def mean(array, axis=None, skipna=None, **kwargs): return _mean(array, axis=axis, skipna=skipna, **kwargs) -mean.numeric_only = True # type: ignore +mean.numeric_only = True # type: ignore[attr-defined] def _nd_cum_func(cum_func, array, axis, **kwargs): @@ -614,15 +627,15 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) -def rolling_window(array, axis, window, center, fill_value): +def sliding_window_view(array, window_shape, axis): """ Make an ndarray with a rolling window of axis-th dimension. The rolling dimension will be placed at the last dimension. """ if is_duck_dask_array(array): - return dask_array_ops.rolling_window(array, axis, window, center, fill_value) - else: # np.ndarray - return nputils.rolling_window(array, axis, window, center, fill_value) + return dask_array_compat.sliding_window_view(array, window_shape, axis) + else: + return npcompat.sliding_window_view(array, window_shape, axis) def least_squares(lhs, rhs, rcond=None, skipna=False): @@ -631,3 +644,12 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) else: return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) + + +def push(array, n, axis): + from bottleneck import push + + if is_duck_dask_array(array): + return dask_array_ops.push(array, n, axis) + else: + return push(array, n, axis) diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index ee4c3ebc9e6..3debefe2e0d 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -38,7 +38,7 @@ def __get__(self, obj, cls): # __getattr__ on data object will swallow any AttributeErrors # raised when initializing the accessor, so we need to raise as # something else (GH933): - raise RuntimeError("error initializing %r accessor." % self._name) + raise RuntimeError(f"error initializing {self._name!r} accessor.") cache[self._name] = accessor_obj return accessor_obj @@ -48,9 +48,8 @@ def _register_accessor(name, cls): def decorator(accessor): if hasattr(cls, name): warnings.warn( - "registration of accessor %r under name %r for type %r is " - "overriding a preexisting attribute with the same name." - % (accessor, name, cls), + f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is " + "overriding a preexisting attribute with the same name.", AccessorRegistrationWarning, stacklevel=2, ) @@ -69,7 +68,7 @@ def register_dataarray_accessor(name): Name under which the accessor should be registered. A warning is issued if this name conflicts with a preexisting attribute. - See also + See Also -------- register_dataset_accessor """ @@ -87,7 +86,6 @@ def register_dataset_accessor(name): Examples -------- - In your library code: >>> @xr.register_dataset_accessor("geo") @@ -115,7 +113,7 @@ def register_dataset_accessor(name): (10.0, 5.0) >>> ds.geo.plot() # plots data on a map - See also + See Also -------- register_dataarray_accessor """ diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 0c1be1cc175..7f292605e63 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -11,7 +11,8 @@ from pandas.errors import OutOfBoundsDatetime from .duck_array_ops import array_equiv -from .options import OPTIONS +from .indexing import MemoryCachedArray +from .options import OPTIONS, _get_boolean_with_default from .pycompat import dask_array_type, sparse_array_type from .utils import is_duck_array @@ -189,9 +190,8 @@ def format_array_flat(array, max_width: int): (max_possibly_relevant < array.size) or (cum_len > max_width).any() ): padding = " ... " - count = min( - array.size, max(np.argmax(cum_len + len(padding) - 1 > max_width), 2) - ) + max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2) # type: ignore[type-var] + count = min(array.size, max_len) else: count = array.size padding = "" if (count <= 1) else " " @@ -258,12 +258,12 @@ def inline_variable_array_repr(var, max_width): """Build a one-line summary of a variable's data.""" if var._in_memory: return format_array_flat(var, max_width) + elif hasattr(var._data, "_repr_inline_"): + return var._data._repr_inline_(max_width) elif isinstance(var._data, dask_array_type): return inline_dask_repr(var.data) elif isinstance(var._data, sparse_array_type): return inline_sparse_repr(var.data) - elif hasattr(var._data, "_repr_inline_"): - return var._data._repr_inline_(max_width) elif hasattr(var._data, "__array_function__"): return maybe_truncate(repr(var._data).replace("\n", " "), max_width) else: @@ -372,7 +372,9 @@ def _calculate_col_width(col_items): return col_width -def _mapping_repr(mapping, title, summarizer, col_width=None, max_rows=None): +def _mapping_repr( + mapping, title, summarizer, expand_option_name, col_width=None, max_rows=None +): if col_width is None: col_width = _calculate_col_width(mapping) if max_rows is None: @@ -380,15 +382,19 @@ def _mapping_repr(mapping, title, summarizer, col_width=None, max_rows=None): summary = [f"{title}:"] if mapping: len_mapping = len(mapping) - if len_mapping > max_rows: + if not _get_boolean_with_default(expand_option_name, default=True): + summary = [f"{summary[0]} ({len_mapping})"] + elif len_mapping > max_rows: summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = max_rows // 2 + max_rows % 2 - items = list(mapping.items()) - summary += [summarizer(k, v, col_width) for k, v in items[:first_rows]] + keys = list(mapping.keys()) + summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]] if max_rows > 1: last_rows = max_rows // 2 summary += [pretty_print(" ...", col_width) + " ..."] - summary += [summarizer(k, v, col_width) for k, v in items[-last_rows:]] + summary += [ + summarizer(k, mapping[k], col_width) for k in keys[-last_rows:] + ] else: summary += [summarizer(k, v, col_width) for k, v in mapping.items()] else: @@ -397,12 +403,18 @@ def _mapping_repr(mapping, title, summarizer, col_width=None, max_rows=None): data_vars_repr = functools.partial( - _mapping_repr, title="Data variables", summarizer=summarize_datavar + _mapping_repr, + title="Data variables", + summarizer=summarize_datavar, + expand_option_name="display_expand_data_vars", ) attrs_repr = functools.partial( - _mapping_repr, title="Attributes", summarizer=summarize_attr + _mapping_repr, + title="Attributes", + summarizer=summarize_attr, + expand_option_name="display_expand_attrs", ) @@ -410,7 +422,11 @@ def coords_repr(coords, col_width=None): if col_width is None: col_width = _calculate_col_width(_get_col_items(coords)) return _mapping_repr( - coords, title="Coordinates", summarizer=summarize_coord, col_width=col_width + coords, + title="Coordinates", + summarizer=summarize_coord, + expand_option_name="display_expand_coords", + col_width=col_width, ) @@ -488,15 +504,26 @@ def short_data_repr(array): def array_repr(arr): + from .variable import Variable + # used for DataArray, Variable and IndexVariable if hasattr(arr, "name") and arr.name is not None: name_str = f"{arr.name!r} " else: name_str = "" + if ( + isinstance(arr, Variable) + or _get_boolean_with_default("display_expand_data", default=True) + or isinstance(arr.variable._data, MemoryCachedArray) + ): + data_repr = short_data_repr(arr) + else: + data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"]) + summary = [ "".format(type(arr).__name__, name_str, dim_summary(arr)), - short_data_repr(arr), + data_repr, ] if hasattr(arr, "coords"): diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 3392aef8da3..2a480427d4e 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -6,6 +6,7 @@ import pkg_resources from .formatting import inline_variable_array_repr, short_data_repr +from .options import _get_boolean_with_default STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css") @@ -24,9 +25,8 @@ def short_data_repr_html(array): internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): return internal_data._repr_html_() - else: - text = escape(short_data_repr(array)) - return f"
{text}
" + text = escape(short_data_repr(array)) + return f"
{text}
" def format_dims(dims, coord_names): @@ -38,7 +38,7 @@ def format_dims(dims, coord_names): } dims_li = "".join( - f"
  • " f"{escape(dim)}: {size}
  • " + f"
  • " f"{escape(str(dim))}: {size}
  • " for dim, size in dims.items() ) @@ -47,7 +47,7 @@ def format_dims(dims, coord_names): def summarize_attrs(attrs): attrs_dl = "".join( - f"
    {escape(k)} :
    " f"
    {escape(str(v))}
    " + f"
    {escape(str(k))} :
    " f"
    {escape(str(v))}
    " for k, v in attrs.items() ) @@ -76,8 +76,7 @@ def summarize_coord(name, var): if is_index: coord = var.variable.to_index_variable() if coord.level_names is not None: - coords = {} - coords[name] = _summarize_coord_multiindex(name, coord) + coords = {name: _summarize_coord_multiindex(name, coord)} for lname in coord.level_names: var = coord.get_level_variable(lname) coords[lname] = summarize_variable(lname, var) @@ -164,9 +163,14 @@ def collapsible_section( ) -def _mapping_section(mapping, name, details_func, max_items_collapse, enabled=True): +def _mapping_section( + mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True +): n_items = len(mapping) - collapsed = n_items >= max_items_collapse + expanded = _get_boolean_with_default( + expand_option_name, n_items < max_items_collapse + ) + collapsed = not expanded return collapsible_section( name, @@ -188,7 +192,11 @@ def dim_section(obj): def array_section(obj): # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) - collapsed = "checked" + collapsed = ( + "checked" + if _get_boolean_with_default("display_expand_data", default=True) + else "" + ) variable = getattr(obj, "variable", obj) preview = escape(inline_variable_array_repr(variable, max_width=70)) data_repr = short_data_repr_html(obj) @@ -209,6 +217,7 @@ def array_section(obj): name="Coordinates", details_func=summarize_coords, max_items_collapse=25, + expand_option_name="display_expand_coords", ) @@ -217,6 +226,7 @@ def array_section(obj): name="Data variables", details_func=summarize_vars, max_items_collapse=15, + expand_option_name="display_expand_data_vars", ) @@ -225,6 +235,7 @@ def array_section(obj): name="Attributes", details_func=summarize_attrs, max_items_collapse=10, + expand_option_name="display_expand_attrs", ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 08532f6d3ed..c2886cc3675 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1,5 +1,4 @@ import datetime -import functools import warnings import numpy as np @@ -7,8 +6,7 @@ from ..plot.plot import _PlotMethods as _DataArray_PlotMethods from . import dtypes, duck_array_ops, nputils, ops -from .arithmetic import SupportsArithmetic -from .common import ImplementsArrayReduce, ImplementsDatasetReduce +from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from .concat import concat from .formatting import format_array_flat from .indexes import propagate_indexes @@ -32,8 +30,8 @@ def check_reduce_dims(reduce_dims, dimensions): reduce_dims = [reduce_dims] if any(dim not in dimensions for dim in reduce_dims): raise ValueError( - "cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r." - % (reduce_dims, dimensions) + f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' " + f"to reduce over all dimensions or one or more of {dimensions!r}." ) @@ -108,7 +106,7 @@ def _consolidate_slices(slices): last_slice = slice(None) for slice_ in slices: if not isinstance(slice_, slice): - raise ValueError("list element is not a slice: %r" % slice_) + raise ValueError(f"list element is not a slice: {slice_!r}") if ( result and last_slice.stop == slice_.start @@ -144,8 +142,7 @@ def _inverse_permutation_indices(positions): return None positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions] - indices = nputils.inverse_permutation(np.concatenate(positions)) - return indices + return nputils.inverse_permutation(np.concatenate(positions)) class _DummyGroup: @@ -203,9 +200,8 @@ def _ensure_1d(group, obj): def _unique_and_monotonic(group): if isinstance(group, _DummyGroup): return True - else: - index = safe_cast_to_index(group) - return index.is_unique and index.is_monotonic + index = safe_cast_to_index(group) + return index.is_unique and index.is_monotonic def _apply_loffset(grouper, result): @@ -234,7 +230,7 @@ def _apply_loffset(grouper, result): grouper.loffset = None -class GroupBy(SupportsArithmetic): +class GroupBy: """A object that implements the split-apply-combine pattern. Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over @@ -385,7 +381,7 @@ def __init__( if len(group_indices) == 0: if bins is not None: raise ValueError( - "None of the data falls within bins with edges %r" % bins + f"None of the data falls within bins with edges {bins!r}" ) else: raise ValueError( @@ -443,7 +439,7 @@ def __iter__(self): return zip(self._unique_coord.values, self._iter_grouped()) def __repr__(self): - return "{}, grouped over {!r} \n{!r} groups with labels {}.".format( + return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( self.__class__.__name__, self._unique_coord.name, self._unique_coord.size, @@ -481,16 +477,10 @@ def _infer_concat_args(self, applied_example): coord = None return coord, dim, positions - @staticmethod - def _binary_op(f, reflexive=False, **ignored_kwargs): - @functools.wraps(f) - def func(self, other): - g = f if not reflexive else lambda x, y: f(y, x) - applied = self._yield_binary_applied(g, other) - combined = self._combine(applied) - return combined - - return func + def _binary_op(self, other, f, reflexive=False): + g = f if not reflexive else lambda x, y: f(y, x) + applied = self._yield_binary_applied(g, other) + return self._combine(applied) def _yield_binary_applied(self, func, other): dummy = None @@ -508,8 +498,8 @@ def _yield_binary_applied(self, func, other): if self._group.name not in other.dims: raise ValueError( "incompatible dimensions for a grouped " - "binary operation: the group variable %r " - "is not a dimension on the other argument" % self._group.name + f"binary operation: the group variable {self._group.name!r} " + "is not a dimension on the other argument" ) if dummy is None: dummy = _dummy_copy(other) @@ -557,13 +547,12 @@ def fillna(self, value): ------- same type as the grouped object - See also + See Also -------- Dataset.fillna DataArray.fillna """ - out = ops.fillna(self, value) - return out + return ops.fillna(self, value) def quantile( self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True @@ -606,12 +595,11 @@ def quantile( See Also -------- - numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile DataArray.quantile Examples -------- - >>> da = xr.DataArray( ... [[1.3, 8.4, 0.7, 6.9], [0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]], ... coords={"x": [0, 0, 1], "y": [1, 1, 2, 2]}, @@ -651,7 +639,7 @@ def quantile( * x (x) int64 0 1 >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...) - Dimensions: (quantile: 3, y: 2) + Dimensions: (y: 2, quantile: 3) Coordinates: * quantile (quantile) float64 0.0 0.5 1.0 * y (y) int64 1 2 @@ -670,7 +658,6 @@ def quantile( keep_attrs=keep_attrs, skipna=skipna, ) - return out def where(self, cond, other=dtypes.NA): @@ -688,7 +675,7 @@ def where(self, cond, other=dtypes.NA): ------- same type as the grouped object - See also + See Also -------- Dataset.where """ @@ -714,7 +701,7 @@ def last(self, skipna=None, keep_attrs=None): def assign_coords(self, coords=None, **coords_kwargs): """Assign coordinates by group. - See also + See Also -------- Dataset.assign_coords Dataset.swap_dims @@ -732,9 +719,11 @@ def _maybe_reorder(xarray_obj, dim, positions): return xarray_obj[{dim: order}] -class DataArrayGroupBy(GroupBy, ImplementsArrayReduce): +class DataArrayGroupBy(GroupBy, DataArrayGroupbyArithmetic): """GroupBy object specialized to grouping DataArray objects""" + __slots__ = () + def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without metadata @@ -750,8 +739,7 @@ def _concat_shortcut(self, applied, dim, positions=None): # compiled language) stacked = Variable.concat(applied, dim, shortcut=True) reordered = _maybe_reorder(stacked, dim, positions) - result = self._obj._replace_maybe_drop_dims(reordered) - return result + return self._obj._replace_maybe_drop_dims(reordered) def _restore_dim_order(self, stacked): def lookup_order(dimension): @@ -808,10 +796,7 @@ def map(self, func, shortcut=False, args=(), **kwargs): applied : DataArray or DataArray The result of splitting, applying and combining this array. """ - if shortcut: - grouped = self._iter_grouped_shortcut() - else: - grouped = self._iter_grouped() + grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped) return self._combine(applied, shortcut=shortcut) @@ -903,11 +888,10 @@ def plot(self): return _DataArray_PlotMethods(self) -ops.inject_reduce_methods(DataArrayGroupBy) -ops.inject_binary_ops(DataArrayGroupBy) +class DatasetGroupBy(GroupBy, DatasetGroupbyArithmetic): + __slots__ = () -class DatasetGroupBy(GroupBy, ImplementsDatasetReduce): def map(self, func, args=(), shortcut=None, **kwargs): """Apply a function to each Dataset in the group and concatenate them together into a new Dataset. @@ -1016,7 +1000,7 @@ def reduce_dataset(ds): def assign(self, **kwargs): """Assign data variables by group. - See also + See Also -------- Dataset.assign """ @@ -1027,7 +1011,3 @@ def plot(self): raise NotImplementedError( "Plotting not implemented for DatasetGroupBy objects yet." ) - - -ops.inject_reduce_methods(DatasetGroupBy) -ops.inject_binary_ops(DatasetGroupBy) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a5d1896e74c..429c37af588 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,12 +1,420 @@ import collections.abc -from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Hashable, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import pandas as pd -from . import formatting -from .utils import is_scalar -from .variable import Variable +from . import formatting, utils +from .indexing import ( + LazilyIndexedArray, + PandasIndexingAdapter, + PandasMultiIndexingAdapter, +) +from .utils import is_dict_like, is_scalar + +if TYPE_CHECKING: + from .variable import IndexVariable, Variable + +IndexVars = Dict[Hashable, "IndexVariable"] + + +class Index: + """Base class inherited by all xarray-compatible indexes.""" + + @classmethod + def from_variables( + cls, variables: Mapping[Hashable, "Variable"] + ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover + raise NotImplementedError() + + def to_pandas_index(self) -> pd.Index: + """Cast this xarray index to a pandas.Index object or raise a TypeError + if this is not supported. + + This method is used by all xarray operations that expect/require a + pandas.Index object. + + """ + raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") + + def query( + self, labels: Dict[Hashable, Any] + ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover + raise NotImplementedError() + + def equals(self, other): # pragma: no cover + raise NotImplementedError() + + def union(self, other): # pragma: no cover + raise NotImplementedError() + + def intersection(self, other): # pragma: no cover + raise NotImplementedError() + + def copy(self, deep: bool = True): # pragma: no cover + raise NotImplementedError() + + def __getitem__(self, indexer: Any): + # if not implemented, index will be dropped from the Dataset or DataArray + raise NotImplementedError() + + +def _sanitize_slice_element(x): + from .dataarray import DataArray + from .variable import Variable + + if not isinstance(x, tuple) and len(np.shape(x)) != 0: + raise ValueError( + f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" + ) + + if isinstance(x, (Variable, DataArray)): + x = x.values + + if isinstance(x, np.ndarray): + x = x[()] + + return x + + +def _query_slice(index, label, coord_name="", method=None, tolerance=None): + if method is not None or tolerance is not None: + raise NotImplementedError( + "cannot use ``method`` argument if any indexers are slice objects" + ) + indexer = index.slice_indexer( + _sanitize_slice_element(label.start), + _sanitize_slice_element(label.stop), + _sanitize_slice_element(label.step), + ) + if not isinstance(indexer, slice): + # unlike pandas, in xarray we never want to silently convert a + # slice indexer into an array indexer + raise KeyError( + "cannot represent labeled-based slice indexer for coordinate " + f"{coord_name!r} with a slice over integer positions; the index is " + "unsorted or non-unique" + ) + return indexer + + +def _asarray_tuplesafe(values): + """ + Convert values into a numpy array of at most 1-dimension, while preserving + tuples. + + Adapted from pandas.core.common._asarray_tuplesafe + """ + if isinstance(values, tuple): + result = utils.to_0d_object_array(values) + else: + result = np.asarray(values) + if result.ndim == 2: + result = np.empty(len(values), dtype=object) + result[:] = values + + return result + + +def _is_nested_tuple(possible_tuple): + return isinstance(possible_tuple, tuple) and any( + isinstance(value, (tuple, list, slice)) for value in possible_tuple + ) + + +def get_indexer_nd(index, labels, method=None, tolerance=None): + """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional + labels + """ + flat_labels = np.ravel(labels) + flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) + indexer = flat_indexer.reshape(labels.shape) + return indexer + + +class PandasIndex(Index): + """Wrap a pandas.Index as an xarray compatible index.""" + + __slots__ = ("index", "dim") + + def __init__(self, array: Any, dim: Hashable): + self.index = utils.safe_cast_to_index(array) + self.dim = dim + + @classmethod + def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + from .variable import IndexVariable + + if len(variables) != 1: + raise ValueError( + f"PandasIndex only accepts one variable, found {len(variables)} variables" + ) + + name, var = next(iter(variables.items())) + + if var.ndim != 1: + raise ValueError( + "PandasIndex only accepts a 1-dimensional variable, " + f"variable {name!r} has {var.ndim} dimensions" + ) + + dim = var.dims[0] + + obj = cls(var.data, dim) + + data = PandasIndexingAdapter(obj.index) + index_var = IndexVariable( + dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True + ) + + return obj, {name: index_var} + + @classmethod + def from_pandas_index(cls, index: pd.Index, dim: Hashable): + from .variable import IndexVariable + + if index.name is None: + name = dim + index = index.copy() + index.name = dim + else: + name = index.name + + data = PandasIndexingAdapter(index) + index_var = IndexVariable(dim, data, fastpath=True) + + return cls(index, dim), {name: index_var} + + def to_pandas_index(self) -> pd.Index: + return self.index + + def query(self, labels, method=None, tolerance=None): + assert len(labels) == 1 + coord_name, label = next(iter(labels.items())) + + if isinstance(label, slice): + indexer = _query_slice(self.index, label, coord_name, method, tolerance) + elif is_dict_like(label): + raise ValueError( + "cannot use a dict-like object for selection on " + "a dimension that does not have a MultiIndex" + ) + else: + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) + if label.ndim == 0: + # see https://github.com/pydata/xarray/pull/4292 for details + label_value = label[()] if label.dtype.kind in "mM" else label.item() + if isinstance(self.index, pd.CategoricalIndex): + if method is not None: + raise ValueError( + "'method' is not a valid kwarg when indexing using a CategoricalIndex." + ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." + ) + indexer = self.index.get_loc(label_value) + else: + indexer = self.index.get_loc( + label_value, method=method, tolerance=tolerance + ) + elif label.dtype.kind == "b": + indexer = label + else: + indexer = get_indexer_nd(self.index, label, method, tolerance) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + return indexer, None + + def equals(self, other): + return self.index.equals(other.index) + + def union(self, other): + new_index = self.index.union(other.index) + return type(self)(new_index, self.dim) + + def intersection(self, other): + new_index = self.index.intersection(other.index) + return type(self)(new_index, self.dim) + + def copy(self, deep=True): + return type(self)(self.index.copy(deep=deep), self.dim) + + def __getitem__(self, indexer: Any): + return type(self)(self.index[indexer], self.dim) + + +def _create_variables_from_multiindex(index, dim, level_meta=None): + from .variable import IndexVariable + + if level_meta is None: + level_meta = {} + + variables = {} + + dim_coord_adapter = PandasMultiIndexingAdapter(index) + variables[dim] = IndexVariable( + dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True + ) + + for level in index.names: + meta = level_meta.get(level, {}) + data = PandasMultiIndexingAdapter( + index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter + ) + variables[level] = IndexVariable( + dim, + data, + attrs=meta.get("attrs"), + encoding=meta.get("encoding"), + fastpath=True, + ) + + return variables + + +class PandasMultiIndex(PandasIndex): + @classmethod + def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") + + dims = set([var.dims for var in variables.values()]) + if len(dims) != 1: + raise ValueError( + "unmatched dimensions for variables " + + ",".join([str(k) for k in variables]) + ) + + dim = next(iter(dims))[0] + index = pd.MultiIndex.from_arrays( + [var.values for var in variables.values()], names=variables.keys() + ) + obj = cls(index, dim) + + level_meta = { + name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} + for name, var in variables.items() + } + index_vars = _create_variables_from_multiindex( + index, dim, level_meta=level_meta + ) + + return obj, index_vars + + @classmethod + def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): + index_vars = _create_variables_from_multiindex(index, dim) + return cls(index, dim), index_vars + + def query(self, labels, method=None, tolerance=None): + if method is not None or tolerance is not None: + raise ValueError( + "multi-index does not support ``method`` and ``tolerance``" + ) + + new_index = None + + # label(s) given for multi-index level(s) + if all([lbl in self.index.names for lbl in labels]): + is_nested_vals = _is_nested_tuple(tuple(labels.values())) + if len(labels) == self.index.nlevels and not is_nested_vals: + indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) + else: + for k, v in labels.items(): + # index should be an item (i.e. Hashable) not an array-like + if isinstance(v, Sequence) and not isinstance(v, str): + raise ValueError( + "Vectorized selection is not " + f"available along coordinate {k!r} (multi-index level)" + ) + indexer, new_index = self.index.get_loc_level( + tuple(labels.values()), level=tuple(labels.keys()) + ) + # GH2619. Raise a KeyError if nothing is chosen + if indexer.dtype.kind == "b" and indexer.sum() == 0: + raise KeyError(f"{labels} not found") + + # assume one label value given for the multi-index "array" (dimension) + else: + if len(labels) > 1: + coord_name = next(iter(set(labels) - set(self.index.names))) + raise ValueError( + f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " + f"and one or more coordinates among {self.index.names!r} (multi-index levels)" + ) + + coord_name, label = next(iter(labels.items())) + + if is_dict_like(label): + invalid_levels = [ + name for name in label if name not in self.index.names + ] + if invalid_levels: + raise ValueError( + f"invalid multi-index level names {invalid_levels}" + ) + return self.query(label) + + elif isinstance(label, slice): + indexer = _query_slice(self.index, label, coord_name) + + elif isinstance(label, tuple): + if _is_nested_tuple(label): + indexer = self.index.get_locs(label) + elif len(label) == self.index.nlevels: + indexer = self.index.get_loc(label) + else: + indexer, new_index = self.index.get_loc_level( + label, level=list(range(len(label))) + ) + + else: + label = ( + label + if getattr(label, "ndim", 1) > 1 # vectorized-indexing + else _asarray_tuplesafe(label) + ) + if label.ndim == 0: + indexer, new_index = self.index.get_loc_level(label.item(), level=0) + elif label.dtype.kind == "b": + indexer = label + else: + if label.ndim > 1: + raise ValueError( + "Vectorized selection is not available along " + f"coordinate {coord_name!r} with a multi-index" + ) + indexer = get_indexer_nd(self.index, label) + if np.any(indexer < 0): + raise KeyError(f"not all values found in index {coord_name!r}") + + if new_index is not None: + if isinstance(new_index, pd.MultiIndex): + new_index, new_vars = PandasMultiIndex.from_pandas_index( + new_index, self.dim + ) + else: + new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim) + return indexer, (new_index, new_vars) + else: + return indexer, None def remove_unused_levels_categories(index: pd.Index) -> pd.Index: @@ -47,7 +455,7 @@ def __init__(self, indexes): Parameters ---------- indexes : Dict[Any, pandas.Index] - Indexes held by this object. + Indexes held by this object. """ self._indexes = indexes @@ -68,14 +476,14 @@ def __repr__(self): def default_indexes( - coords: Mapping[Any, Variable], dims: Iterable -) -> Dict[Hashable, pd.Index]: + coords: Mapping[Any, "Variable"], dims: Iterable +) -> Dict[Hashable, Index]: """Default indexes for a Dataset/DataArray. Parameters ---------- coords : Mapping[Any, xarray.Variable] - Coordinate variables from which to draw default indexes. + Coordinate variables from which to draw default indexes. dims : iterable Iterable of dimension names. @@ -84,16 +492,24 @@ def default_indexes( Mapping from indexing keys (levels/dimension names) to indexes used for indexing along that dimension. """ - return {key: coords[key].to_index() for key in dims if key in coords} + return {key: coords[key]._to_xindex() for key in dims if key in coords} def isel_variable_and_index( name: Hashable, - variable: Variable, - index: pd.Index, - indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]], -) -> Tuple[Variable, Optional[pd.Index]]: - """Index a Variable and pandas.Index together.""" + variable: "Variable", + index: Index, + indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]], +) -> Tuple["Variable", Optional[Index]]: + """Index a Variable and an Index together. + + If the index cannot be indexed, return None (it will be dropped). + + (note: not compatible yet with xarray flexible indexes). + + """ + from .variable import Variable + if not indexers: # nothing to index return variable.copy(deep=False), index @@ -114,22 +530,28 @@ def isel_variable_and_index( indexer = indexers[dim] if isinstance(indexer, Variable): indexer = indexer.data - new_index = index[indexer] + try: + new_index = index[indexer] + except NotImplementedError: + new_index = None + return new_variable, new_index -def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index: +def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: """Roll an pandas.Index.""" - count %= index.shape[0] + pd_index = index.to_pandas_index() + count %= pd_index.shape[0] if count != 0: - return index[-count:].append(index[:-count]) + new_idx = pd_index[-count:].append(pd_index[:-count]) else: - return index[:] + new_idx = pd_index[:] + return PandasIndex(new_idx, index.dim) def propagate_indexes( - indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None -) -> Optional[Dict[Hashable, pd.Index]]: + indexes: Optional[Dict[Hashable, Index]], exclude: Optional[Any] = None +) -> Optional[Dict[Hashable, Index]]: """Creates new indexes dict from existing dict optionally excluding some dimensions.""" if exclude is None: exclude = () @@ -140,6 +562,6 @@ def propagate_indexes( if indexes is not None: new_indexes = {k: v for k, v in indexes.items() if k not in exclude} else: - new_indexes = None # type: ignore + new_indexes = None # type: ignore[assignment] return new_indexes diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 843feb04479..70994a36ac8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import suppress from datetime import timedelta -from typing import Any, Callable, Iterable, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -13,11 +13,12 @@ from .npcompat import DTypeLike from .pycompat import ( dask_array_type, + dask_version, integer_types, is_duck_dask_array, sparse_array_type, ) -from .utils import is_dict_like, maybe_cast_to_coords_dtype +from .utils import maybe_cast_to_coords_dtype def expanded_indexer(key, ndim): @@ -54,190 +55,44 @@ def _expand_slice(slice_, size): return np.arange(*slice_.indices(size)) -def _sanitize_slice_element(x): - from .dataarray import DataArray - from .variable import Variable +def group_indexers_by_index(data_obj, indexers, method=None, tolerance=None): + # TODO: benbovy - flexible indexes: indexers are still grouped by dimension + # - Make xarray.Index hashable so that it can be used as key in a mapping? + indexes = {} + grouped_indexers = defaultdict(dict) - if isinstance(x, (Variable, DataArray)): - x = x.values + # TODO: data_obj.xindexes should eventually return the PandasIndex instance + # for each multi-index levels + xindexes = dict(data_obj.xindexes) + for level, dim in data_obj._level_coords.items(): + xindexes[level] = xindexes[dim] - if isinstance(x, np.ndarray): - if x.ndim != 0: - raise ValueError( - f"cannot use non-scalar arrays in a slice for xarray indexing: {x}" - ) - x = x[()] - - return x - - -def _asarray_tuplesafe(values): - """ - Convert values into a numpy array of at most 1-dimension, while preserving - tuples. - - Adapted from pandas.core.common._asarray_tuplesafe - """ - if isinstance(values, tuple): - result = utils.to_0d_object_array(values) - else: - result = np.asarray(values) - if result.ndim == 2: - result = np.empty(len(values), dtype=object) - result[:] = values - - return result - - -def _is_nested_tuple(possible_tuple): - return isinstance(possible_tuple, tuple) and any( - isinstance(value, (tuple, list, slice)) for value in possible_tuple - ) - - -def get_indexer_nd(index, labels, method=None, tolerance=None): - """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional - labels - """ - flat_labels = np.ravel(labels) - flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) - indexer = flat_indexer.reshape(labels.shape) - return indexer - - -def convert_label_indexer(index, label, index_name="", method=None, tolerance=None): - """Given a pandas.Index and labels (e.g., from __getitem__) for one - dimension, return an indexer suitable for indexing an ndarray along that - dimension. If `index` is a pandas.MultiIndex and depending on `label`, - return a new pandas.Index or pandas.MultiIndex (otherwise return None). - """ - new_index = None - - if isinstance(label, slice): - if method is not None or tolerance is not None: - raise NotImplementedError( - "cannot use ``method`` argument if any indexers are slice objects" - ) - indexer = index.slice_indexer( - _sanitize_slice_element(label.start), - _sanitize_slice_element(label.stop), - _sanitize_slice_element(label.step), - ) - if not isinstance(indexer, slice): - # unlike pandas, in xarray we never want to silently convert a - # slice indexer into an array indexer - raise KeyError( - "cannot represent labeled-based slice indexer for dimension " - f"{index_name!r} with a slice over integer positions; the index is " - "unsorted or non-unique" - ) - - elif is_dict_like(label): - is_nested_vals = _is_nested_tuple(tuple(label.values())) - if not isinstance(index, pd.MultiIndex): - raise ValueError( - "cannot use a dict-like object for selection on " - "a dimension that does not have a MultiIndex" - ) - elif len(label) == index.nlevels and not is_nested_vals: - indexer = index.get_loc(tuple(label[k] for k in index.names)) - else: - for k, v in label.items(): - # index should be an item (i.e. Hashable) not an array-like - if isinstance(v, Sequence) and not isinstance(v, str): - raise ValueError( - "Vectorized selection is not " - "available along level variable: " + k - ) - indexer, new_index = index.get_loc_level( - tuple(label.values()), level=tuple(label.keys()) - ) + for key, label in indexers.items(): + try: + index = xindexes[key] + coord = data_obj.coords[key] + dim = coord.dims[0] + if dim not in indexes: + indexes[dim] = index - # GH2619. Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: - raise KeyError(f"{label} not found") + label = maybe_cast_to_coords_dtype(label, coord.dtype) + grouped_indexers[dim][key] = label - elif isinstance(label, tuple) and isinstance(index, pd.MultiIndex): - if _is_nested_tuple(label): - indexer = index.get_locs(label) - elif len(label) == index.nlevels: - indexer = index.get_loc(label) - else: - indexer, new_index = index.get_loc_level( - label, level=list(range(len(label))) - ) - else: - label = ( - label - if getattr(label, "ndim", 1) > 1 # vectorized-indexing - else _asarray_tuplesafe(label) - ) - if label.ndim == 0: - # see https://github.com/pydata/xarray/pull/4292 for details - label_value = label[()] if label.dtype.kind in "mM" else label.item() - if isinstance(index, pd.MultiIndex): - indexer, new_index = index.get_loc_level(label_value, level=0) - elif isinstance(index, pd.CategoricalIndex): - if method is not None: - raise ValueError( - "'method' is not a valid kwarg when indexing using a CategoricalIndex." - ) - if tolerance is not None: - raise ValueError( - "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." - ) - indexer = index.get_loc(label_value) - else: - indexer = index.get_loc(label_value, method=method, tolerance=tolerance) - elif label.dtype.kind == "b": - indexer = label - else: - if isinstance(index, pd.MultiIndex) and label.ndim > 1: + except KeyError: + if key in data_obj.coords: + raise KeyError(f"no index found for coordinate {key}") + elif key not in data_obj.dims: + raise KeyError(f"{key} is not a valid dimension or coordinate") + # key is a dimension without coordinate: we'll reuse the provided labels + elif method is not None or tolerance is not None: raise ValueError( - "Vectorized selection is not available along " - "MultiIndex variable: " + index_name + "cannot supply ``method`` or ``tolerance`` " + "when the indexed dimension does not have " + "an associated coordinate." ) - indexer = get_indexer_nd(index, label, method, tolerance) - if np.any(indexer < 0): - raise KeyError(f"not all values found in index {index_name!r}") - return indexer, new_index + grouped_indexers[None][key] = label - -def get_dim_indexers(data_obj, indexers): - """Given a xarray data object and label based indexers, return a mapping - of label indexers with only dimension names as keys. - - It groups multiple level indexers given on a multi-index dimension - into a single, dictionary indexer for that dimension (Raise a ValueError - if it is not possible). - """ - invalid = [ - k - for k in indexers - if k not in data_obj.dims and k not in data_obj._level_coords - ] - if invalid: - raise ValueError(f"dimensions or multi-index levels {invalid!r} do not exist") - - level_indexers = defaultdict(dict) - dim_indexers = {} - for key, label in indexers.items(): - (dim,) = data_obj[key].dims - if key != dim: - # assume here multi-index level indexer - level_indexers[dim][key] = label - else: - dim_indexers[key] = label - - for dim, level_labels in level_indexers.items(): - if dim_indexers.get(dim, False): - raise ValueError( - "cannot combine multi-index level indexers with an indexer for " - f"dimension {dim}" - ) - dim_indexers[dim] = level_labels - - return dim_indexers + return indexes, grouped_indexers def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): @@ -251,26 +106,25 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): pos_indexers = {} new_indexes = {} - dim_indexers = get_dim_indexers(data_obj, indexers) - for dim, label in dim_indexers.items(): - try: - index = data_obj.indexes[dim] - except KeyError: - # no index for this dimension: reuse the provided labels - if method is not None or tolerance is not None: - raise ValueError( - "cannot supply ``method`` or ``tolerance`` " - "when the indexed dimension does not have " - "an associated coordinate." - ) + indexes, grouped_indexers = group_indexers_by_index( + data_obj, indexers, method, tolerance + ) + + forward_pos_indexers = grouped_indexers.pop(None, None) + if forward_pos_indexers is not None: + for dim, label in forward_pos_indexers.items(): pos_indexers[dim] = label - else: - coords_dtype = data_obj.coords[dim].dtype - label = maybe_cast_to_coords_dtype(label, coords_dtype) - idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance) - pos_indexers[dim] = idxr - if new_idx is not None: - new_indexes[dim] = new_idx + + for dim, index in indexes.items(): + labels = grouped_indexers[dim] + idxr, new_idx = index.query(labels, method=method, tolerance=tolerance) + pos_indexers[dim] = idxr + if new_idx is not None: + new_indexes[dim] = new_idx + + # TODO: benbovy - flexible indexes: support the following cases: + # - an index query returns positional indexers over multiple dimensions + # - check/combine positional indexers returned by multiple indexes over the same dimension return pos_indexers, new_indexes @@ -513,7 +367,7 @@ def __getitem__(self, key): return result -class LazilyOuterIndexedArray(ExplicitlyIndexedNDArrayMixin): +class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): """Wrap an array to make basic and outer indexing lazy.""" __slots__ = ("array", "key") @@ -589,6 +443,10 @@ def __repr__(self): return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})" +# keep an alias to the old name for external backends pydata/xarray#5111 +LazilyOuterIndexedArray = LazilyIndexedArray + + class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): """Wrap an array to make vectorized indexing lazy.""" @@ -619,10 +477,10 @@ def _updated_key(self, new_key): return _combine_indexers(self.key, self.shape, new_key) def __getitem__(self, indexer): - # If the indexed array becomes a scalar, return LazilyOuterIndexedArray + # If the indexed array becomes a scalar, return LazilyIndexedArray if all(isinstance(ind, integer_types) for ind in indexer.tuple): key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple)) - return LazilyOuterIndexedArray(self.array, key) + return LazilyIndexedArray(self.array, key) return type(self)(self.array, self._updated_key(indexer)) def transpose(self, order): @@ -714,7 +572,7 @@ def as_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) if isinstance(array, pd.Index): - return PandasIndexAdapter(array) + return PandasIndexingAdapter(array) if isinstance(array, dask_array_type): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): @@ -787,11 +645,11 @@ def _combine_indexers(old_key, shape, new_key): Parameters ---------- - old_key: ExplicitIndexer + old_key : ExplicitIndexer The first indexer for the original array - shape: tuple of ints + shape : tuple of ints Shape of the original array to be indexed by old_key - new_key: + new_key The second indexer for indexing original[old_key] """ if not isinstance(old_key, VectorizedIndexer): @@ -841,7 +699,7 @@ def explicit_indexing_adapter( Shape of the indexed array. indexing_support : IndexingSupport enum Form of indexing supported by raw_indexing_method. - raw_indexing_method: callable + raw_indexing_method : callable Function (like ndarray.__getitem__) that when called with indexing key in the form of a tuple returns an indexed array. @@ -895,8 +753,8 @@ def _decompose_vectorized_indexer( Parameters ---------- - indexer: VectorizedIndexer - indexing_support: one of IndexerSupport entries + indexer : VectorizedIndexer + indexing_support : one of IndexerSupport entries Returns ------- @@ -977,8 +835,8 @@ def _decompose_outer_indexer( Parameters ---------- - indexer: OuterIndexer or BasicIndexer - indexing_support: One of the entries of IndexingSupport + indexer : OuterIndexer or BasicIndexer + indexing_support : One of the entries of IndexingSupport Returns ------- @@ -1010,7 +868,7 @@ def _decompose_outer_indexer( return indexer, BasicIndexer(()) assert isinstance(indexer, (OuterIndexer, BasicIndexer)) - backend_indexer = [] + backend_indexer: List[Any] = [] np_indexer = [] # make indexer positive pos_indexer = [] @@ -1094,7 +952,7 @@ def _decompose_outer_indexer( def _arrayize_vectorized_indexer(indexer, shape): - """ Return an identical vindex but slices are replaced by arrays """ + """Return an identical vindex but slices are replaced by arrays""" slices = [v for v in indexer.tuple if isinstance(v, slice)] if len(slices) == 0: return indexer @@ -1376,38 +1234,55 @@ def __getitem__(self, key): return value def __setitem__(self, key, value): - raise TypeError( - "this variable's data is stored in a dask array, " - "which does not support item assignment. To " - "assign to this variable, you must first load it " - "into memory explicitly using the .load() " - "method or accessing its .values attribute." - ) + if dask_version >= "2021.04.1": + if isinstance(key, BasicIndexer): + self.array[key.tuple] = value + elif isinstance(key, VectorizedIndexer): + self.array.vindex[key.tuple] = value + elif isinstance(key, OuterIndexer): + num_non_slices = sum( + 0 if isinstance(k, slice) else 1 for k in key.tuple + ) + if num_non_slices > 1: + raise NotImplementedError( + "xarray can't set arrays with multiple " + "array indices to dask yet." + ) + self.array[key.tuple] = value + else: + raise TypeError( + "This variable's data is stored in a dask array, " + "and the installed dask version does not support item " + "assignment. To assign to this variable, you must either upgrade dask or" + "first load the variable into memory explicitly using the .load() " + "method or accessing its .values attribute." + ) def transpose(self, order): return self.array.transpose(order) -class PandasIndexAdapter(ExplicitlyIndexedNDArrayMixin): +class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = ("array", "_dtype") - def __init__(self, array: Any, dtype: DTypeLike = None): + def __init__(self, array: pd.Index, dtype: DTypeLike = None): self.array = utils.safe_cast_to_index(array) + if dtype is None: if isinstance(array, pd.PeriodIndex): - dtype = np.dtype("O") + dtype_ = np.dtype("O") elif hasattr(array, "categories"): # category isn't a real numpy dtype - dtype = array.categories.dtype + dtype_ = array.categories.dtype elif not utils.is_valid_numpy_dtype(array.dtype): - dtype = np.dtype("O") + dtype_ = np.dtype("O") else: - dtype = array.dtype + dtype_ = array.dtype else: - dtype = np.dtype(dtype) - self._dtype = dtype + dtype_ = np.dtype(dtype) # type: ignore[assignment] + self._dtype = dtype_ @property def dtype(self) -> np.dtype: @@ -1429,7 +1304,13 @@ def shape(self) -> Tuple[int]: def __getitem__( self, indexer - ) -> Union[NumpyIndexingAdapter, np.ndarray, np.datetime64, np.timedelta64]: + ) -> Union[ + "PandasIndexingAdapter", + NumpyIndexingAdapter, + np.ndarray, + np.datetime64, + np.timedelta64, + ]: key = indexer.tuple if isinstance(key, tuple) and len(key) == 1: # unpack key so it can index a pandas.Index object (pandas.Index @@ -1442,7 +1323,7 @@ def __getitem__( result = self.array[key] if isinstance(result, pd.Index): - result = PandasIndexAdapter(result, dtype=self.dtype) + result = type(self)(result, dtype=self.dtype) else: # result is a scalar if result is pd.NaT: @@ -1470,11 +1351,9 @@ def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional def __repr__(self) -> str: - return "{}(array={!r}, dtype={!r})".format( - type(self).__name__, self.array, self.dtype - ) + return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" - def copy(self, deep: bool = True) -> "PandasIndexAdapter": + def copy(self, deep: bool = True) -> "PandasIndexingAdapter": # Not the same as just writing `self.array.copy(deep=deep)`, as # shallow copies of the underlying numpy.ndarrays become deep ones # upon pickling @@ -1483,4 +1362,47 @@ def copy(self, deep: bool = True) -> "PandasIndexAdapter": # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) # 8000341 array = self.array.copy(deep=True) if deep else self.array - return PandasIndexAdapter(array, self._dtype) + return type(self)(array, self._dtype) + + +class PandasMultiIndexingAdapter(PandasIndexingAdapter): + """Handles explicit indexing for a pandas.MultiIndex. + + This allows creating one instance for each multi-index level while + preserving indexing efficiency (memoized + might reuse another instance with + the same multi-index). + + """ + + __slots__ = ("array", "_dtype", "level", "adapter") + + def __init__( + self, + array: pd.MultiIndex, + dtype: DTypeLike = None, + level: Optional[str] = None, + adapter: Optional[PandasIndexingAdapter] = None, + ): + super().__init__(array, dtype) + self.level = level + self.adapter = adapter + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if self.level is not None: + return self.array.get_level_values(self.level).values + else: + return super().__array__(dtype) + + @functools.lru_cache(1) + def __getitem__(self, indexer): + if self.adapter is None: + return super().__getitem__(indexer) + else: + return self.adapter.__getitem__(indexer) + + def __repr__(self) -> str: + if self.level is None: + return super().__repr__() + else: + props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + return f"{type(self).__name__}{props}" diff --git a/xarray/core/merge.py b/xarray/core/merge.py index d29a9e1ff02..b8b32bdaa01 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -20,7 +20,8 @@ from . import dtypes, pdcompat from .alignment import deep_align from .duck_array_ops import lazy_array_equiv -from .utils import Frozen, compat_dict_union, dict_equiv +from .indexes import Index, PandasIndex +from .utils import Frozen, compat_dict_union, dict_equiv, equivalent from .variable import Variable, as_variable, assert_unique_multiindex_level_names if TYPE_CHECKING: @@ -56,6 +57,13 @@ ) +class Context: + """object carrying the information of a call""" + + def __init__(self, func): + self.func = func + + def broadcast_dimension_size(variables: List[Variable]) -> Dict[Hashable, int]: """Extract dimension sizes from a dictionary of variables. @@ -65,7 +73,7 @@ def broadcast_dimension_size(variables: List[Variable]) -> Dict[Hashable, int]: for var in variables: for dim, size in zip(var.dims, var.shape): if dim in dims and size != dims[dim]: - raise ValueError("index %r not aligned" % dim) + raise ValueError(f"index {dim!r} not aligned") dims[dim] = size return dims @@ -157,14 +165,15 @@ def _assert_compat_valid(compat): ) -MergeElement = Tuple[Variable, Optional[pd.Index]] +MergeElement = Tuple[Variable, Optional[Index]] def merge_collected( grouped: Dict[Hashable, List[MergeElement]], prioritized: Mapping[Hashable, MergeElement] = None, compat: str = "minimal", -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: + combine_attrs="override", +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. Parameters @@ -186,7 +195,7 @@ def merge_collected( _assert_compat_valid(compat) merged_vars: Dict[Hashable, Variable] = {} - merged_indexes: Dict[Hashable, pd.Index] = {} + merged_indexes: Dict[Hashable, Index] = {} for name, elements_list in grouped.items(): if name in prioritized: @@ -209,19 +218,21 @@ def merge_collected( for _, other_index in indexed_elements[1:]: if not index.equals(other_index): raise MergeError( - "conflicting values for index %r on objects to be " - "combined:\nfirst value: %r\nsecond value: %r" - % (name, index, other_index) + f"conflicting values for index {name!r} on objects to be " + f"combined:\nfirst value: {index!r}\nsecond value: {other_index!r}" ) if compat == "identical": for other_variable, _ in indexed_elements[1:]: if not dict_equiv(variable.attrs, other_variable.attrs): raise MergeError( "conflicting attribute values on combined " - "variable %r:\nfirst value: %r\nsecond value: %r" - % (name, variable.attrs, other_variable.attrs) + f"variable {name!r}:\nfirst value: {variable.attrs!r}\nsecond value: {other_variable.attrs!r}" ) merged_vars[name] = variable + merged_vars[name].attrs = merge_attrs( + [var.attrs for var, _ in indexed_elements], + combine_attrs=combine_attrs, + ) merged_indexes[name] = index else: variables = [variable for variable, _ in elements_list] @@ -233,6 +244,11 @@ def merge_collected( # we drop conflicting coordinates) raise + if name in merged_vars: + merged_vars[name].attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) + return merged_vars, merged_indexes @@ -251,7 +267,7 @@ def collect_variables_and_indexes( from .dataarray import DataArray from .dataset import Dataset - grouped: Dict[Hashable, List[Tuple[Variable, pd.Index]]] = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} def append(name, variable, index): values = grouped.setdefault(name, []) @@ -263,13 +279,13 @@ def append_all(variables, indexes): for mapping in list_of_mappings: if isinstance(mapping, Dataset): - append_all(mapping.variables, mapping.indexes) + append_all(mapping.variables, mapping.xindexes) continue for name, variable in mapping.items(): if isinstance(variable, DataArray): coords = variable._coords.copy() # use private API for speed - indexes = dict(variable.indexes) + indexes = dict(variable.xindexes) # explicitly overwritten variables should take precedence coords.pop(name, None) indexes.pop(name, None) @@ -278,7 +294,7 @@ def append_all(variables, indexes): variable = as_variable(variable, name=name) if variable.dims == (name,): variable = variable.to_index_variable() - index = variable.to_index() + index = variable._to_xindex() else: index = None append(name, variable, index) @@ -290,11 +306,11 @@ def collect_from_coordinates( list_of_coords: "List[Coordinates]", ) -> Dict[Hashable, List[MergeElement]]: """Collect variables and indexes to be merged from Coordinate objects.""" - grouped: Dict[Hashable, List[Tuple[Variable, pd.Index]]] = {} + grouped: Dict[Hashable, List[Tuple[Variable, Optional[Index]]]] = {} for coords in list_of_coords: variables = coords.variables - indexes = coords.indexes + indexes = coords.xindexes for name, variable in variables.items(): value = grouped.setdefault(name, []) value.append((variable, indexes.get(name))) @@ -305,7 +321,8 @@ def merge_coordinates_without_align( objects: "List[Coordinates]", prioritized: Mapping[Hashable, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: + combine_attrs: str = "override", +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. This function is used for merging coordinate from pre-existing xarray @@ -326,7 +343,7 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized) + return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) def determine_coords( @@ -438,9 +455,9 @@ def merge_coords( compat: str = "minimal", join: str = "outer", priority_arg: Optional[int] = None, - indexes: Optional[Mapping[Hashable, pd.Index]] = None, + indexes: Optional[Mapping[Hashable, Index]] = None, fill_value: object = dtypes.NA, -) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: +) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge coordinate variables. See merge_core below for argument descriptions. This works similarly to @@ -474,7 +491,7 @@ def _extract_indexes_from_coords(coords): for name, variable in coords.items(): variable = as_variable(variable, name=name) if variable.dims == (name,): - yield name, variable.to_index() + yield name, variable._to_xindex() def assert_valid_explicit_coords(variables, dims, explicit_coords): @@ -486,19 +503,21 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): for coord_name in explicit_coords: if coord_name in dims and variables[coord_name].dims != (coord_name,): raise MergeError( - "coordinate %s shares a name with a dataset dimension, but is " + f"coordinate {coord_name} shares a name with a dataset dimension, but is " "not a 1D variable along that dimension. This is disallowed " - "by the xarray data model." % coord_name + "by the xarray data model." ) -def merge_attrs(variable_attrs, combine_attrs): +def merge_attrs(variable_attrs, combine_attrs, context=None): """Combine attributes from different variables according to combine_attrs""" if not variable_attrs: # no attributes to merge return None - if combine_attrs == "drop": + if callable(combine_attrs): + return combine_attrs(variable_attrs, context=context) + elif combine_attrs == "drop": return {} elif combine_attrs == "override": return dict(variable_attrs[0]) @@ -507,23 +526,41 @@ def merge_attrs(variable_attrs, combine_attrs): for attrs in variable_attrs[1:]: try: result = compat_dict_union(result, attrs) - except ValueError: + except ValueError as e: raise MergeError( "combine_attrs='no_conflicts', but some values are not " - "the same. Merging %s with %s" % (str(result), str(attrs)) - ) + f"the same. Merging {str(result)} with {str(attrs)}" + ) from e + return result + elif combine_attrs == "drop_conflicts": + result = {} + dropped_keys = set() + for attrs in variable_attrs: + result.update( + { + key: value + for key, value in attrs.items() + if key not in result and key not in dropped_keys + } + ) + result = { + key: value + for key, value in result.items() + if key not in attrs or equivalent(attrs[key], value) + } + dropped_keys |= {key for key in attrs if key not in result} return result elif combine_attrs == "identical": result = dict(variable_attrs[0]) for attrs in variable_attrs[1:]: if not dict_equiv(result, attrs): raise MergeError( - "combine_attrs='identical', but attrs differ. First is %s " - ", other is %s." % (str(result), str(attrs)) + f"combine_attrs='identical', but attrs differ. First is {str(result)} " + f", other is {str(attrs)}." ) return result else: - raise ValueError("Unrecognised value for combine_attrs=%s" % combine_attrs) + raise ValueError(f"Unrecognised value for combine_attrs={combine_attrs}") class _MergeResult(NamedTuple): @@ -541,7 +578,7 @@ def merge_core( combine_attrs: Optional[str] = "override", priority_arg: Optional[int] = None, explicit_coords: Optional[Sequence] = None, - indexes: Optional[Mapping[Hashable, pd.Index]] = None, + indexes: Optional[Mapping[Hashable, Any]] = None, fill_value: object = dtypes.NA, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -556,14 +593,16 @@ def merge_core( Compatibility checks to use when merging variables. join : {"outer", "inner", "left", "right"}, optional How to combine objects with different indexes. - combine_attrs : {"drop", "identical", "no_conflicts", "override"}, optional + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" How to combine attributes of objects priority_arg : int, optional Optional argument in `objects` that takes precedence over the others. explicit_coords : set, optional An explicit list of variables from `objects` that are coordinates. indexes : dict, optional - Dictionary with values given by pandas.Index objects. + Dictionary with values given by xarray.Index objects or anything that + may be cast to pandas.Index objects. fill_value : scalar, optional Value to use for newly missing values @@ -594,7 +633,9 @@ def merge_core( collected = collect_variables_and_indexes(aligned) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) - variables, out_indexes = merge_collected(collected, prioritized, compat=compat) + variables, out_indexes = merge_collected( + collected, prioritized, compat=compat, combine_attrs=combine_attrs + ) assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) @@ -610,15 +651,11 @@ def merge_core( if ambiguous_coords: raise MergeError( "unable to determine if these variables should be " - "coordinates or not in the merged result: %s" % ambiguous_coords + f"coordinates or not in the merged result: {ambiguous_coords}" ) attrs = merge_attrs( - [ - var.attrs - for var in coerced - if isinstance(var, Dataset) or isinstance(var, DataArray) - ], + [var.attrs for var in coerced if isinstance(var, (Dataset, DataArray))], combine_attrs, ) @@ -630,7 +667,7 @@ def merge( compat: str = "no_conflicts", join: str = "outer", fill_value: object = dtypes.NA, - combine_attrs: str = "drop", + combine_attrs: str = "override", ) -> "Dataset": """Merge any number of xarray objects into a single Dataset as variables. @@ -668,17 +705,23 @@ def merge( Value to use for newly missing values. If a dict-like, maps variable names to fill values. Use a data array's name to refer to its values. - combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ - default: "drop" - String indicating how to combine attrs of the objects being merged: + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to the result. + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. + Returns ------- Dataset @@ -686,7 +729,6 @@ def merge( Examples -------- - >>> import xarray as xr >>> x = xr.DataArray( ... [[1.0, 2.0], [3.0, 5.0]], ... dims=("lat", "lon"), @@ -839,6 +881,8 @@ def merge( See also -------- concat + combine_nested + combine_by_coords """ from .dataarray import DataArray from .dataset import Dataset @@ -861,8 +905,7 @@ def merge( combine_attrs=combine_attrs, fill_value=fill_value, ) - merged = Dataset._construct_direct(**merge_result._asdict()) - return merged + return Dataset._construct_direct(**merge_result._asdict()) def dataset_merge_method( @@ -872,6 +915,7 @@ def dataset_merge_method( compat: str, join: str, fill_value: Any, + combine_attrs: str, ) -> _MergeResult: """Guts of the Dataset.merge method.""" # we are locked into supporting overwrite_vars for the Dataset.merge @@ -901,7 +945,12 @@ def dataset_merge_method( priority_arg = 2 return merge_core( - objs, compat, join, priority_arg=priority_arg, fill_value=fill_value + objs, + compat, + join, + priority_arg=priority_arg, + fill_value=fill_value, + combine_attrs=combine_attrs, ) @@ -931,10 +980,17 @@ def dataset_update_method( other[key] = value.drop_vars(coord_names) # use ds.coords and not ds.indexes, else str coords are cast to object - indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()} + # TODO: benbovy - flexible indexes: make it work with any xarray index + indexes = {} + for key, index in dataset.xindexes.items(): + if isinstance(index, PandasIndex): + indexes[key] = dataset.coords[key] + else: + indexes[key] = index + return merge_core( [dataset, other], priority_arg=1, - indexes=indexes, + indexes=indexes, # type: ignore combine_attrs="override", ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 8d112b4603c..36983a227b9 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -11,9 +11,9 @@ from . import utils from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc -from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric -from .options import _get_keep_attrs -from .pycompat import is_duck_dask_array +from .duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric +from .options import OPTIONS, _get_keep_attrs +from .pycompat import dask_version, is_duck_dask_array from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables @@ -93,7 +93,7 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None): self._left = fill_value self._right = fill_value else: - raise ValueError("%s is not a valid fill_value" % fill_value) + raise ValueError(f"{fill_value} is not a valid fill_value") def __call__(self, x): return self.f( @@ -154,7 +154,7 @@ def __init__( yi, kind=self.method, fill_value=fill_value, - bounds_error=False, + bounds_error=bounds_error, assume_sorted=assume_sorted, copy=copy, **self.cons_kwargs, @@ -216,20 +216,20 @@ def get_clean_interp_index( Parameters ---------- arr : DataArray - Array to interpolate or fit to a curve. + Array to interpolate or fit to a curve. dim : str - Name of dimension along which to fit. + Name of dimension along which to fit. use_coordinate : str or bool - If use_coordinate is True, the coordinate that shares the name of the - dimension along which interpolation is being performed will be used as the - x values. If False, the x values are set as an equally spaced sequence. + If use_coordinate is True, the coordinate that shares the name of the + dimension along which interpolation is being performed will be used as the + x values. If False, the x values are set as an equally spaced sequence. strict : bool - Whether to raise errors if the index is either non-unique or non-monotonic (default). + Whether to raise errors if the index is either non-unique or non-monotonic (default). Returns ------- Variable - Numerical values for the x-coordinates. + Numerical values for the x-coordinates. Notes ----- @@ -317,9 +317,13 @@ def interp_na( if not is_scalar(max_gap): raise ValueError("max_gap must be a scalar.") + # TODO: benbovy - flexible indexes: update when CFTimeIndex (and DatetimeIndex?) + # has its own class inheriting from xarray.Index if ( - dim in self.indexes - and isinstance(self.indexes[dim], (pd.DatetimeIndex, CFTimeIndex)) + dim in self.xindexes + and isinstance( + self.xindexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex) + ) and use_coordinate ): # Convert to float @@ -390,12 +394,10 @@ def func_interpolate_na(interpolator, y, x, **kwargs): def _bfill(arr, n=None, axis=-1): """inverse of ffill""" - import bottleneck as bn - arr = np.flip(arr, axis=axis) # fill - arr = bn.push(arr, axis=axis, n=n) + arr = push(arr, axis=axis, n=n) # reverse back to original return np.flip(arr, axis=axis) @@ -403,7 +405,11 @@ def _bfill(arr, n=None, axis=-1): def ffill(arr, dim=None, limit=None): """forward fill missing values""" - import bottleneck as bn + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "ffill requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) axis = arr.get_axis_num(dim) @@ -411,9 +417,9 @@ def ffill(arr, dim=None, limit=None): _limit = limit if limit is not None else arr.shape[axis] return apply_ufunc( - bn.push, + push, arr, - dask="parallelized", + dask="allowed", keep_attrs=True, output_dtypes=[arr.dtype], kwargs=dict(n=_limit, axis=axis), @@ -422,6 +428,12 @@ def ffill(arr, dim=None, limit=None): def bfill(arr, dim=None, limit=None): """backfill missing values""" + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "bfill requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) + axis = arr.get_axis_num(dim) # work around for bottleneck 178 @@ -430,7 +442,7 @@ def bfill(arr, dim=None, limit=None): return apply_ufunc( _bfill, arr, - dask="parallelized", + dask="allowed", keep_attrs=True, output_dtypes=[arr.dtype], kwargs=dict(n=_limit, axis=axis), @@ -589,16 +601,16 @@ def interp(var, indexes_coords, method, **kwargs): Parameters ---------- - var: Variable - index_coords: + var : Variable + indexes_coords Mapping from dimension name to a pair of original and new coordinates. Original coordinates should be sorted in strictly ascending order. Note that all the coordinates should be Variable objects. - method: string + method : string One of {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'}. For multidimensional interpolation, only {'linear', 'nearest'} can be used. - **kwargs: + **kwargs keyword arguments to be passed to scipy.interpolate Returns @@ -621,10 +633,6 @@ def interp(var, indexes_coords, method, **kwargs): for indexes_coords in decompose_interp(indexes_coords): var = result - # simple speed up for the local interpolation - if method in ["linear", "nearest"]: - var, indexes_coords = _localize(var, indexes_coords) - # target dimensions dims = list(indexes_coords) x, new_x = zip(*[indexes_coords[d] for d in dims]) @@ -658,17 +666,17 @@ def interp_func(var, x, new_x, method, kwargs): Parameters ---------- - var: np.ndarray or dask.array.Array + var : np.ndarray or dask.array.Array Array to be interpolated. The final dimension is interpolated. - x: a list of 1d array. + x : a list of 1d array. Original coordinates. Should not contain NaN. - new_x: a list of 1d array + new_x : a list of 1d array New coordinates. Should not contain NaN. - method: string + method : string {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation - **kwargs: + **kwargs Optional keyword arguments to be passed to scipy.interpolator Returns @@ -676,8 +684,8 @@ def interp_func(var, x, new_x, method, kwargs): interpolated: array Interpolated array - Note - ---- + Notes + ----- This requiers scipy installed. See Also @@ -695,21 +703,22 @@ def interp_func(var, x, new_x, method, kwargs): if is_duck_dask_array(var): import dask.array as da - nconst = var.ndim - len(x) + ndim = var.ndim + nconst = ndim - len(x) - out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim)) + out_ind = list(range(nconst)) + list(range(ndim, ndim + new_x[0].ndim)) # blockwise args format x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] x_arginds = [item for pair in x_arginds for item in pair] new_x_arginds = [ - [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x + [_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x ] new_x_arginds = [item for pair in new_x_arginds for item in pair] args = ( var, - range(var.ndim), + range(ndim), *x_arginds, *new_x_arginds, ) @@ -721,7 +730,7 @@ def interp_func(var, x, new_x, method, kwargs): new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] new_axes = { - var.ndim + i: new_x[0].chunks[i] + ndim + i: new_x[0].chunks[i] if new_x[0].chunks is not None else new_x[0].shape[i] for i in range(new_x[0].ndim) @@ -737,6 +746,13 @@ def interp_func(var, x, new_x, method, kwargs): else: dtype = var.dtype + if dask_version < "2020.12": + # Using meta and dtype at the same time doesn't work. + # Remove this whenever the minimum requirement for dask is 2020.12: + meta = None + else: + meta = var._meta + return da.blockwise( _dask_aware_interpnd, out_ind, @@ -747,6 +763,8 @@ def interp_func(var, x, new_x, method, kwargs): concatenate=True, dtype=dtype, new_axes=new_axes, + meta=meta, + align_arrays=False, ) return _interpnd(var, x, new_x, func, kwargs) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 5eb88bcd096..48106bff289 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -3,7 +3,14 @@ import numpy as np from . import dtypes, nputils, utils -from .duck_array_ops import _dask_or_eager_func, count, fillna, isnull, where_method +from .duck_array_ops import ( + _dask_or_eager_func, + count, + fillna, + isnull, + where, + where_method, +) from .pycompat import dask_array_type try: @@ -12,7 +19,7 @@ from . import dask_array_compat except ImportError: dask_array = None - dask_array_compat = None # type: ignore + dask_array_compat = None # type: ignore[assignment] def _replace_nan(a, val): @@ -28,18 +35,14 @@ def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out """ - if axis is not None and getattr(result, "ndim", False): null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 - if null_mask.any(): - dtype, fill_value = dtypes.maybe_promote(result.dtype) - result = result.astype(dtype) - result[null_mask] = fill_value + dtype, fill_value = dtypes.maybe_promote(result.dtype) + result = where(null_mask, fill_value, result.astype(dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: null_mask = mask.size - mask.sum() - if null_mask < min_count: - result = np.nan + result = where(null_mask < min_count, np.nan, result) return result @@ -60,7 +63,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs): - """ In house nanmin and nanmax for object array """ + """In house nanmin and nanmax for object array""" valid_count = count(value, axis=axis) filled_value = fillna(value, fill_value) data = getattr(np, func)(filled_value, axis=axis, **kwargs) @@ -116,7 +119,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None): def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): - """ In house nanmean. ddof argument will be used in _nanvar method """ + """In house nanmean. ddof argument will be used in _nanvar method""" from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method valid_count = count(value, axis=axis) @@ -152,9 +155,7 @@ def nanmedian(a, axis=None, out=None): # possibly blow memory if axis is not None and len(np.atleast_1d(axis)) == a.ndim: axis = None - return _dask_or_eager_func( - "nanmedian", dask_module=dask_array_compat, eager_module=nputils - )(a, axis=axis) + return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis) def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 1018332df29..6e22c8cf0a4 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -28,69 +28,155 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import builtins -import operator -from typing import Union +import sys +from distutils.version import LooseVersion +from typing import TYPE_CHECKING, Any, Sequence, TypeVar, Union import numpy as np - -# Vendored from NumPy 1.12; we need a version that support duck typing, even -# on dask arrays with __array_function__ enabled. -def _validate_axis(axis, ndim, argname): - try: - axis = [operator.index(axis)] - except TypeError: - axis = list(axis) - axis = [a + ndim if a < 0 else a for a in axis] - if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError("invalid axis for this array in `%s` argument" % argname) - if len(set(axis)) != len(axis): - raise ValueError("repeated axis in `%s` argument" % argname) - return axis - - -def moveaxis(a, source, destination): - try: - # allow duck-array types if they define transpose - transpose = a.transpose - except AttributeError: - a = np.asarray(a) - transpose = a.transpose - - source = _validate_axis(source, a.ndim, "source") - destination = _validate_axis(destination, a.ndim, "destination") - if len(source) != len(destination): - raise ValueError( - "`source` and `destination` arguments must have " - "the same number of elements" +# Type annotations stubs +try: + from numpy.typing import ArrayLike, DTypeLike +except ImportError: + # fall back for numpy < 1.20, ArrayLike adapted from numpy.typing._array_like + if sys.version_info >= (3, 8): + from typing import Protocol + + HAVE_PROTOCOL = True + else: + try: + from typing_extensions import Protocol + except ImportError: + HAVE_PROTOCOL = False + else: + HAVE_PROTOCOL = True + + if TYPE_CHECKING or HAVE_PROTOCOL: + + class _SupportsArray(Protocol): + def __array__(self) -> np.ndarray: + ... + + else: + _SupportsArray = Any + + _T = TypeVar("_T") + _NestedSequence = Union[ + _T, + Sequence[_T], + Sequence[Sequence[_T]], + Sequence[Sequence[Sequence[_T]]], + Sequence[Sequence[Sequence[Sequence[_T]]]], + ] + _RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]] + _ArrayLike = Union[ + _NestedSequence[_SupportsArray], + _NestedSequence[_T], + ] + _ArrayLikeFallback = Union[ + _ArrayLike[Union[bool, int, float, complex, str, bytes]], + _RecursiveSequence, + ] + # The extra step defining _ArrayLikeFallback and using ArrayLike as a type + # alias for it works around an issue with mypy. + # The `# type: ignore` below silences the warning of having multiple types + # with the same name (ArrayLike and DTypeLike from the try block) + ArrayLike = _ArrayLikeFallback # type: ignore + # fall back for numpy < 1.20 + DTypeLike = Union[np.dtype, str] # type: ignore[misc] + + +if LooseVersion(np.__version__) >= "1.20.0": + sliding_window_view = np.lib.stride_tricks.sliding_window_view +else: + from numpy.core.numeric import normalize_axis_tuple # type: ignore[attr-defined] + from numpy.lib.stride_tricks import as_strided + + # copied from numpy.lib.stride_tricks + def sliding_window_view( + x, window_shape, axis=None, *, subok=False, writeable=False + ): + """ + Create a sliding window view into the array with the given window shape. + + Also known as rolling or moving window, the window slides across all + dimensions of the array and extracts subsets of the array at all window + positions. + + .. versionadded:: 1.20.0 + + Parameters + ---------- + x : array_like + Array to create the sliding window view from. + window_shape : int or tuple of int + Size of window over each axis that takes part in the sliding window. + If `axis` is not present, must have same length as the number of input + array dimensions. Single integers `i` are treated as if they were the + tuple `(i,)`. + axis : int or tuple of int, optional + Axis or axes along which the sliding window is applied. + By default, the sliding window is applied to all axes and + `window_shape[i]` will refer to axis `i` of `x`. + If `axis` is given as a `tuple of int`, `window_shape[i]` will refer to + the axis `axis[i]` of `x`. + Single integers `i` are treated as if they were the tuple `(i,)`. + subok : bool, optional + If True, sub-classes will be passed-through, otherwise the returned + array will be forced to be a base-class array (default). + writeable : bool, optional + When true, allow writing to the returned view. The default is false, + as this should be used with caution: the returned view contains the + same memory location multiple times, so writing to one location will + cause others to change. + + Returns + ------- + view : ndarray + Sliding window view of the array. The sliding window dimensions are + inserted at the end, and the original dimensions are trimmed as + required by the size of the sliding window. + That is, ``view.shape = x_shape_trimmed + window_shape``, where + ``x_shape_trimmed`` is ``x.shape`` with every entry reduced by one less + than the corresponding window size. + """ + window_shape = ( + tuple(window_shape) if np.iterable(window_shape) else (window_shape,) + ) + # first convert input to array, possibly keeping subclass + x = np.array(x, copy=False, subok=subok) + + window_shape_array = np.array(window_shape) + if np.any(window_shape_array < 0): + raise ValueError("`window_shape` cannot contain negative values") + + if axis is None: + axis = tuple(range(x.ndim)) + if len(window_shape) != len(axis): + raise ValueError( + f"Since axis is `None`, must provide " + f"window_shape for all dimensions of `x`; " + f"got {len(window_shape)} window_shape elements " + f"and `x.ndim` is {x.ndim}." + ) + else: + axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) + if len(window_shape) != len(axis): + raise ValueError( + f"Must provide matching length window_shape and " + f"axis; got {len(window_shape)} window_shape " + f"elements and {len(axis)} axes elements." + ) + + out_strides = x.strides + tuple(x.strides[ax] for ax in axis) + + # note: same axis can be windowed repeatedly + x_shape_trimmed = list(x.shape) + for ax, dim in zip(axis, window_shape): + if x_shape_trimmed[ax] < dim: + raise ValueError("window shape cannot be larger than input array shape") + x_shape_trimmed[ax] -= dim - 1 + out_shape = tuple(x_shape_trimmed) + window_shape + return as_strided( + x, strides=out_strides, shape=out_shape, subok=subok, writeable=writeable ) - - order = [n for n in range(a.ndim) if n not in source] - - for dest, src in sorted(zip(destination, source)): - order.insert(dest, src) - - result = transpose(order) - return result - - -# Type annotations stubs. See also / to be replaced by: -# https://github.com/numpy/numpy/issues/7370 -# https://github.com/numpy/numpy-stubs/ -DTypeLike = Union[np.dtype, str] - - -# from dask/array/utils.py -def _is_nep18_active(): - class A: - def __array_function__(self, *args, **kwargs): - return True - - try: - return np.concatenate([A()]) - except ValueError: - return False - - -IS_NEP18_ACTIVE = _is_nep18_active() diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index c65c22f5384..3e0f550dd30 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -2,7 +2,9 @@ import numpy as np import pandas as pd -from numpy.core.multiarray import normalize_axis_index +from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] + +from .options import OPTIONS try: import bottleneck as bn @@ -131,81 +133,6 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def rolling_window(a, axis, window, center, fill_value): - """ rolling window with padding. """ - pads = [(0, 0) for s in a.shape] - if not hasattr(axis, "__len__"): - axis = [axis] - window = [window] - center = [center] - - for ax, win, cent in zip(axis, window, center): - if cent: - start = int(win / 2) # 10 -> 5, 9 -> 4 - end = win - 1 - start - pads[ax] = (start, end) - else: - pads[ax] = (win - 1, 0) - a = np.pad(a, pads, mode="constant", constant_values=fill_value) - for ax, win in zip(axis, window): - a = _rolling_window(a, win, ax) - return a - - -def _rolling_window(a, window, axis=-1): - """ - Make an ndarray with a rolling window along axis. - - Parameters - ---------- - a : array_like - Array to add rolling window to - axis: int - axis position along which rolling window will be applied. - window : int - Size of rolling window - - Returns - ------- - Array that is a view of the original array with a added dimension - of size w. - - Examples - -------- - >>> x = np.arange(10).reshape((2, 5)) - >>> _rolling_window(x, 3, axis=-1) - array([[[0, 1, 2], - [1, 2, 3], - [2, 3, 4]], - - [[5, 6, 7], - [6, 7, 8], - [7, 8, 9]]]) - - Calculate rolling mean of last dimension: - >>> np.mean(_rolling_window(x, 3, axis=-1), -1) - array([[1., 2., 3.], - [6., 7., 8.]]) - - This function is taken from https://github.com/numpy/numpy/pull/31 - but slightly modified to accept axis option. - """ - axis = normalize_axis_index(axis, a.ndim) - a = np.swapaxes(a, axis, -1) - - if window < 1: - raise ValueError(f"`window` must be at least 1. Given : {window}") - if window > a.shape[-1]: - raise ValueError(f"`window` is too long. Given : {window}") - - shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) - strides = a.strides + (a.strides[-1],) - rolling = np.lib.stride_tricks.as_strided( - a, shape=shape, strides=strides, writeable=False - ) - return np.swapaxes(rolling, -2, axis) - - def _create_bottleneck_method(name, npmodule=np): def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) @@ -213,6 +140,7 @@ def f(values, axis=None, **kwargs): if ( _USE_BOTTLENECK + and OPTIONS["use_bottleneck"] and isinstance(values, np.ndarray) and bn_func is not None and not isinstance(axis, tuple) diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d56b0d59df0..8265035a25c 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -10,7 +10,6 @@ import numpy as np from . import dtypes, duck_array_ops -from .nputils import array_eq, array_ne try: import bottleneck as bn @@ -22,8 +21,6 @@ has_bottleneck = False -UNARY_OPS = ["neg", "pos", "abs", "invert"] -CMP_BINARY_OPS = ["lt", "le", "ge", "gt"] NUM_BINARY_OPS = [ "add", "sub", @@ -40,9 +37,7 @@ # methods which pass on the numpy return value unchanged # be careful not to list methods that we would want to wrap later NUMPY_SAME_METHODS = ["item", "searchsorted"] -# methods which don't modify the data shape, so the result should still be -# wrapped in an Variable/DataArray -NUMPY_UNARY_METHODS = ["argsort", "clip", "conj", "conjugate"] + # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ @@ -114,23 +109,12 @@ _MINCOUNT_DOCSTRING = """ min_count : int, default: None - The required number of valid values to perform the operation. - If fewer than min_count non-NA values are present the result will - be NA. New in version 0.10.8: Added with the default being None.""" - -_COARSEN_REDUCE_DOCSTRING_TEMPLATE = """\ -Coarsen this object by applying `{name}` along its dimensions. - -Parameters ----------- -**kwargs : dict - Additional keyword arguments passed on to `{name}`. - -Returns -------- -reduced : DataArray or Dataset - New object with `{name}` applied along its coasen dimnensions. -""" + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. New in version 0.10.8: Added with the default being + None. Changed in version 0.17.0: if specified on an integer array + and skipna=True, the result will be a float array.""" def fillna(data, other, join="left", dataset_join="left"): @@ -252,7 +236,7 @@ def func(self, *args, **kwargs): def inject_reduce_methods(cls): methods = ( [ - (name, getattr(duck_array_ops, "array_%s" % name), False) + (name, getattr(duck_array_ops, f"array_{name}"), False) for name in REDUCE_METHODS ] + [(name, getattr(duck_array_ops, name), True) for name in NAN_REDUCE_METHODS] @@ -291,7 +275,7 @@ def inject_cum_methods(cls): def op_str(name): - return "__%s__" % name + return f"__{name}__" def get_op(name): @@ -305,42 +289,44 @@ def inplace_to_noninplace_op(f): return NON_INPLACE_OP[f] -def inject_binary_ops(cls, inplace=False): - for name in CMP_BINARY_OPS + NUM_BINARY_OPS: - setattr(cls, op_str(name), cls._binary_op(get_op(name))) +# _typed_ops.py uses the following wrapped functions as a kind of unary operator +argsort = _method_wrapper("argsort") +conj = _method_wrapper("conj") +conjugate = _method_wrapper("conjugate") +round_ = _func_slash_method_wrapper(duck_array_ops.around, name="round") + + +def inject_numpy_same(cls): + # these methods don't return arrays of the same shape as the input, so + # don't try to patch these in for Dataset objects + for name in NUMPY_SAME_METHODS: + setattr(cls, name, _values_method_wrapper(name)) + + +class IncludeReduceMethods: + __slots__ = () + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) - for name, f in [("eq", array_eq), ("ne", array_ne)]: - setattr(cls, op_str(name), cls._binary_op(f)) + if getattr(cls, "_reduce_method", None): + inject_reduce_methods(cls) - for name in NUM_BINARY_OPS: - # only numeric operations have in-place and reflexive variants - setattr(cls, op_str("r" + name), cls._binary_op(get_op(name), reflexive=True)) - if inplace: - setattr(cls, op_str("i" + name), cls._inplace_binary_op(get_op("i" + name))) +class IncludeCumMethods: + __slots__ = () -def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): - # prioritize our operations over those of numpy.ndarray (priority=1) - # and numpy.matrix (priority=10) - cls.__array_priority__ = priority + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) - # patch in standard special operations - for name in UNARY_OPS: - setattr(cls, op_str(name), cls._unary_op(get_op(name))) - inject_binary_ops(cls, inplace=True) + if getattr(cls, "_reduce_method", None): + inject_cum_methods(cls) - # patch in numpy/pandas methods - for name in NUMPY_UNARY_METHODS: - setattr(cls, name, cls._unary_op(_method_wrapper(name))) - f = _func_slash_method_wrapper(duck_array_ops.around, name="round") - setattr(cls, "round", cls._unary_op(f)) +class IncludeNumpySameMethods: + __slots__ = () - if array_only: - # these methods don't return arrays of the same shape as the input, so - # don't try to patch these in for Dataset objects - for name in NUMPY_SAME_METHODS: - setattr(cls, name, _values_method_wrapper(name)) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) - inject_reduce_methods(cls) - inject_cum_methods(cls) + inject_numpy_same(cls) # some methods not applicable to Dataset objects diff --git a/xarray/core/options.py b/xarray/core/options.py index d421b4c4f17..1d916ff0f7c 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -6,10 +6,15 @@ DISPLAY_MAX_ROWS = "display_max_rows" DISPLAY_STYLE = "display_style" DISPLAY_WIDTH = "display_width" +DISPLAY_EXPAND_ATTRS = "display_expand_attrs" +DISPLAY_EXPAND_COORDS = "display_expand_coords" +DISPLAY_EXPAND_DATA_VARS = "display_expand_data_vars" +DISPLAY_EXPAND_DATA = "display_expand_data" ENABLE_CFTIMEINDEX = "enable_cftimeindex" FILE_CACHE_MAXSIZE = "file_cache_maxsize" KEEP_ATTRS = "keep_attrs" WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files" +USE_BOTTLENECK = "use_bottleneck" OPTIONS = { @@ -19,10 +24,15 @@ DISPLAY_MAX_ROWS: 12, DISPLAY_STYLE: "html", DISPLAY_WIDTH: 80, + DISPLAY_EXPAND_ATTRS: "default", + DISPLAY_EXPAND_COORDS: "default", + DISPLAY_EXPAND_DATA_VARS: "default", + DISPLAY_EXPAND_DATA: "default", ENABLE_CFTIMEINDEX: True, FILE_CACHE_MAXSIZE: 128, KEEP_ATTRS: "default", WARN_FOR_UNCLOSED_FILES: False, + USE_BOTTLENECK: True, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -38,10 +48,15 @@ def _positive_integer(value): DISPLAY_MAX_ROWS: _positive_integer, DISPLAY_STYLE: _DISPLAY_OPTIONS.__contains__, DISPLAY_WIDTH: _positive_integer, + DISPLAY_EXPAND_ATTRS: lambda choice: choice in [True, False, "default"], + DISPLAY_EXPAND_COORDS: lambda choice: choice in [True, False, "default"], + DISPLAY_EXPAND_DATA_VARS: lambda choice: choice in [True, False, "default"], + DISPLAY_EXPAND_DATA: lambda choice: choice in [True, False, "default"], ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), FILE_CACHE_MAXSIZE: _positive_integer, KEEP_ATTRS: lambda choice: choice in [True, False, "default"], WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), + USE_BOTTLENECK: lambda value: isinstance(value, bool), } @@ -65,8 +80,8 @@ def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): } -def _get_keep_attrs(default): - global_choice = OPTIONS["keep_attrs"] +def _get_boolean_with_default(option, default): + global_choice = OPTIONS[option] if global_choice == "default": return default @@ -74,10 +89,14 @@ def _get_keep_attrs(default): return global_choice else: raise ValueError( - "The global option keep_attrs must be one of True, False or 'default'." + f"The global option {option} must be one of True, False or 'default'." ) +def _get_keep_attrs(default): + return _get_boolean_with_default("keep_attrs", default) + + class set_options: """Set options for xarray in a controlled context. @@ -85,6 +104,7 @@ class set_options: - ``display_width``: maximum display width for ``repr`` on xarray objects. Default: ``80``. + - ``display_max_rows``: maximum display rows. Default: ``12``. - ``arithmetic_join``: DataArray/Dataset alignment in binary operations. Default: ``'inner'``. - ``file_cache_maxsize``: maximum number of open files to hold in xarray's @@ -105,8 +125,27 @@ class set_options: attrs, ``False`` to always discard them, or ``'default'`` to use original logic that attrs should only be kept in unambiguous circumstances. Default: ``'default'``. + - ``use_bottleneck``: allow using bottleneck. Either ``True`` to accelerate + operations using bottleneck if it is installed or ``False`` to never use it. + Default: ``True`` - ``display_style``: display style to use in jupyter for xarray objects. - Default: ``'text'``. Other options are ``'html'``. + Default: ``'html'``. Other options are ``'text'``. + - ``display_expand_attrs``: whether to expand the attributes section for + display of ``DataArray`` or ``Dataset`` objects. Can be ``True`` to always + expand, ``False`` to always collapse, or ``default`` to expand unless over + a pre-defined limit. Default: ``default``. + - ``display_expand_coords``: whether to expand the coordinates section for + display of ``DataArray`` or ``Dataset`` objects. Can be ``True`` to always + expand, ``False`` to always collapse, or ``default`` to expand unless over + a pre-defined limit. Default: ``default``. + - ``display_expand_data``: whether to expand the data section for display + of ``DataArray`` objects. Can be ``True`` to always expand, ``False`` to + always collapse, or ``default`` to expand unless over a pre-defined limit. + Default: ``default``. + - ``display_expand_data_vars``: whether to expand the data variables section + for display of ``Dataset`` objects. Can be ``True`` to always + expand, ``False`` to always collapse, or ``default`` to expand unless over + a pre-defined limit. Default: ``default``. You can use ``set_options`` either as a context manager: @@ -133,8 +172,7 @@ def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in OPTIONS: raise ValueError( - "argument name %r is not in the set of valid options %r" - % (k, set(OPTIONS)) + f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}" ) if k in _VALIDATORS and not _VALIDATORS[k](v): if k == ARITHMETIC_JOIN: diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 20b4b9f9eb3..2c7f4249b5e 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -73,7 +73,7 @@ def check_result_variables( def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): - raise TypeError("Expected Dataset, got %s" % type(obj)) + raise TypeError(f"Expected Dataset, got {type(obj)}") if len(obj.data_vars) > 1: raise TypeError( @@ -102,7 +102,7 @@ def make_meta(obj): """ if isinstance(obj, DataArray): obj_array = obj - obj = obj._to_temp_dataset() + obj = dataarray_to_dataset(obj) elif isinstance(obj, Dataset): obj_array = None else: @@ -116,7 +116,7 @@ def make_meta(obj): meta = meta.set_coords(obj.coords) if obj_array is not None: - return obj_array._from_temp_dataset(meta) + return dataset_to_dataarray(meta) return meta @@ -183,7 +183,6 @@ def map_blocks( This function must return either a single DataArray or a single Dataset. This function cannot add a new chunked dimension. - obj : DataArray, Dataset Passed to the function as its first argument, one block at a time. args : sequence @@ -201,7 +200,6 @@ def map_blocks( When provided, ``attrs`` on variables in `template` are copied over to the result. Any ``attrs`` set by ``func`` will be ignored. - Returns ------- A single DataArray or Dataset with dask backend, reassembled from the outputs of the @@ -210,20 +208,19 @@ def map_blocks( Notes ----- This function is designed for when ``func`` needs to manipulate a whole xarray object - subset to each block. In the more common case where ``func`` can work on numpy arrays, it is - recommended to use ``apply_ufunc``. + subset to each block. Each block is loaded into memory. In the more common case where + ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``. If none of the variables in ``obj`` is backed by dask arrays, calling this function is equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks xarray.DataArray.map_blocks Examples -------- - Calculate an anomaly from climatology using ``.groupby()``. Using ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, its indices, and its methods like ``.groupby()``. @@ -261,7 +258,7 @@ def map_blocks( ... template=array, ... ) # doctest: +ELLIPSIS - dask.array + dask.array<-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray> Coordinates: * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 month (time) int64 dask.array @@ -294,11 +291,12 @@ def _wrapper( ) # check that index lengths and values are as expected - for name, index in result.indexes.items(): + for name, index in result.xindexes.items(): if name in expected["shapes"]: - if len(index) != expected["shapes"][name]: + if result.sizes[name] != expected["shapes"][name]: raise ValueError( - f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}." + f"Received dimension {name!r} of length {result.sizes[name]}. " + f"Expected length {expected['shapes'][name]}." ) if name in expected["indexes"]: expected_index = expected["indexes"][name] @@ -360,29 +358,31 @@ def _wrapper( # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0].indexes) + input_indexes = dict(npargs[0].xindexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg.indexes) + input_indexes.update(arg.xindexes) if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template.indexes) + template_indexes = set(template.xindexes) preserved_indexes = template_indexes & set(input_indexes) new_indexes = template_indexes - set(input_indexes) indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) + indexes.update({k: template.xindexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template.indexes) + indexes = dict(template.xindexes) if isinstance(template, DataArray): - output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore + output_chunks = dict( + zip(template.dims, template.chunks) # type: ignore[arg-type] + ) else: output_chunks = dict(template.chunks) @@ -451,7 +451,7 @@ def subset_dataset_to_block( for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple + chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], @@ -465,7 +465,7 @@ def subset_dataset_to_block( } subset = variable.isel(subsetter) chunk_variable_task = ( - "{}-{}".format(gname, dask.base.tokenize(subset)), + f"{name}-{gname}-{dask.base.tokenize(subset)}", ) + chunk_tuple graph[chunk_variable_task] = ( tuple, @@ -500,8 +500,8 @@ def subset_dataset_to_block( expected["shapes"] = { k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks } - expected["data_vars"] = set(template.data_vars.keys()) # type: ignore - expected["coords"] = set(template.coords.keys()) # type: ignore + expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] + expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] expected["indexes"] = { dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] for dim in indexes @@ -515,7 +515,7 @@ def subset_dataset_to_block( for name, variable in template.variables.items(): if name in indexes: continue - gname_l = f"{gname}-{name}" + gname_l = f"{name}-{gname}" var_key_map[name] = gname_l key: Tuple[Any, ...] = (gname_l,) @@ -541,13 +541,23 @@ def subset_dataset_to_block( dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], ) - for gname_l, layer in new_layers.items(): - # This adds in the getitems for each variable in the dataset. - hlg.dependencies[gname_l] = {gname} - hlg.layers[gname_l] = layer + # This adds in the getitems for each variable in the dataset. + hlg = HighLevelGraph( + {**hlg.layers, **new_layers}, + dependencies={ + **hlg.dependencies, + **{name: {gname} for name in new_layers.keys()}, + }, + ) + + # TODO: benbovy - flexible indexes: make it work with custom indexes + # this will need to pass both indexes and coords to the Dataset constructor + result = Dataset( + coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, + attrs=template.attrs, + ) - result = Dataset(coords=indexes, attrs=template.attrs) - for index in result.indexes: + for index in result.xindexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding @@ -557,8 +567,8 @@ def subset_dataset_to_block( for dim in dims: if dim in output_chunks: var_chunks.append(output_chunks[dim]) - elif dim in indexes: - var_chunks.append((len(indexes[dim]),)) + elif dim in result.xindexes: + var_chunks.append((result.sizes[dim],)) elif dim in template.dims: # new unindexed dimension var_chunks.append((template.sizes[dim],)) @@ -574,5 +584,5 @@ def subset_dataset_to_block( if result_is_array: da = dataset_to_dataarray(result) da.name = template_name - return da # type: ignore - return result # type: ignore + return da # type: ignore[return-value] + return result # type: ignore[return-value] diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index f2e22329fc8..ba67cc91f06 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -46,7 +46,7 @@ Panel = pd.Panel else: - class Panel: # type: ignore + class Panel: # type: ignore[no-redef] pass diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 8d613038957..d1649235006 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,37 +1,63 @@ +from distutils.version import LooseVersion +from importlib import import_module + import numpy as np from .utils import is_duck_array integer_types = (int, np.integer) -try: - import dask.array - from dask.base import is_dask_collection - # solely for isinstance checks - dask_array_type = (dask.array.Array,) +class DuckArrayModule: + """ + Solely for internal isinstance and version checks. - def is_duck_dask_array(x): - return is_duck_array(x) and is_dask_collection(x) + Motivated by having to only import pint when required (as pint currently imports xarray) + https://github.com/pydata/xarray/pull/5561#discussion_r664815718 + """ + + def __init__(self, mod): + try: + duck_array_module = import_module(mod) + duck_array_version = LooseVersion(duck_array_module.__version__) + + if mod == "dask": + duck_array_type = (import_module("dask.array").Array,) + elif mod == "pint": + duck_array_type = (duck_array_module.Quantity,) + elif mod == "cupy": + duck_array_type = (duck_array_module.ndarray,) + elif mod == "sparse": + duck_array_type = (duck_array_module.SparseArray,) + else: + raise NotImplementedError + except ImportError: # pragma: no cover + duck_array_module = None + duck_array_version = LooseVersion("0.0.0") + duck_array_type = () -except ImportError: # pragma: no cover - dask_array_type = () - is_duck_dask_array = lambda _: False - is_dask_collection = lambda _: False + self.module = duck_array_module + self.version = duck_array_version + self.type = duck_array_type + self.available = duck_array_module is not None + + +def is_duck_dask_array(x): + if DuckArrayModule("dask").available: + from dask.base import is_dask_collection + + return is_duck_array(x) and is_dask_collection(x) + else: + return False -try: - # solely for isinstance checks - import sparse - sparse_array_type = (sparse.SparseArray,) -except ImportError: # pragma: no cover - sparse_array_type = () +dsk = DuckArrayModule("dask") +dask_version = dsk.version +dask_array_type = dsk.type -try: - # solely for isinstance checks - import cupy +sp = DuckArrayModule("sparse") +sparse_array_type = sp.type +sparse_version = sp.version - cupy_array_type = (cupy.ndarray,) -except ImportError: # pragma: no cover - cupy_array_type = () +cupy_array_type = DuckArrayModule("cupy").type diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 0a20d918bf1..c7749a7e5ca 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,6 +1,5 @@ import warnings -from . import ops from .groupby import DataArrayGroupBy, DatasetGroupBy RESAMPLE_DIM = "__resample_dim__" @@ -248,10 +247,6 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -ops.inject_reduce_methods(DataArrayResample) -ops.inject_binary_ops(DataArrayResample) - - class DatasetResample(DatasetGroupBy, Resample): """DatasetGroupBy object specialized to resampling a specified dimension""" @@ -346,7 +341,3 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): removed. """ return super().reduce(func, dim, keep_attrs, **kwargs) - - -ops.inject_reduce_methods(DatasetResample) -ops.inject_binary_ops(DatasetResample) diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 882664cbb60..4a413902b90 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -146,7 +146,7 @@ def _get_time_bins(index, freq, closed, label, base): if not isinstance(index, CFTimeIndex): raise TypeError( "index must be a CFTimeIndex, but got " - "an instance of %r" % type(index).__name__ + f"an instance of {type(index).__name__!r}" ) if len(index) == 0: datetime_bins = labels = CFTimeIndex(data=[], name=index.name) @@ -163,11 +163,7 @@ def _get_time_bins(index, freq, closed, label, base): datetime_bins, freq, closed, index, labels ) - if label == "right": - labels = labels[1:] - else: - labels = labels[:-1] - + labels = labels[1:] if label == "right" else labels[:-1] # TODO: when CFTimeIndex supports missing values, if the reference index # contains missing values, insert the appropriate NaN value at the # beginning of the datetime_bins and labels indexes. @@ -262,11 +258,7 @@ def _get_range_edges(first, last, offset, closed="left", base=0): first = normalize_date(first) last = normalize_date(last) - if closed == "left": - first = offset.rollback(first) - else: - first = first - offset - + first = offset.rollback(first) if closed == "left" else first - offset last = last + offset return first, last @@ -321,11 +313,7 @@ def _adjust_dates_anchored(first, last, offset, closed="right", base=0): else: lresult = last else: - if foffset.total_seconds() > 0: - fresult = first - foffset - else: - fresult = first - + fresult = first - foffset if foffset.total_seconds() > 0 else first if loffset.total_seconds() > 0: lresult = last + (offset.as_timedelta() - loffset) else: diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 39d889244dc..0cac9f2b129 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -1,14 +1,15 @@ import functools +import itertools import warnings from typing import Any, Callable, Dict import numpy as np from . import dtypes, duck_array_ops, utils -from .dask_array_ops import dask_rolling_wrapper -from .ops import inject_reduce_methods -from .options import _get_keep_attrs +from .arithmetic import CoarsenArithmetic +from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array +from .utils import either_dict_or_kwargs try: import bottleneck @@ -47,10 +48,10 @@ class Rolling: xarray.DataArray.rolling """ - __slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs") - _attributes = ("window", "min_periods", "center", "dim", "keep_attrs") + __slots__ = ("obj", "window", "min_periods", "center", "dim") + _attributes = ("window", "min_periods", "center", "dim") - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -88,15 +89,6 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None self.min_periods = np.prod(self.window) if min_periods is None else min_periods - if keep_attrs is not None: - warnings.warn( - "Passing ``keep_attrs`` to ``rolling`` is deprecated and will raise an" - " error in xarray 0.18. Please pass ``keep_attrs`` directly to the" - " applied function. Note that keep_attrs is now True per default.", - FutureWarning, - ) - self.keep_attrs = keep_attrs - def __repr__(self): """provide a nice str repr of our rolling object""" @@ -111,8 +103,16 @@ def __repr__(self): def __len__(self): return self.obj.sizes[self.dim] - def _reduce_method(name: str) -> Callable: # type: ignore - array_agg_func = getattr(duck_array_ops, name) + def _reduce_method( # type: ignore[misc] + name: str, fillna, rolling_agg_func: Callable = None + ) -> Callable: + """Constructs reduction methods built on a numpy reduction function (e.g. sum), + a bottleneck reduction function (e.g. move_sum), or a Rolling reduction (_mean).""" + if rolling_agg_func: + array_agg_func = None + else: + array_agg_func = getattr(duck_array_ops, name) + bottleneck_move_func = getattr(bottleneck, "move_" + name, None) def method(self, keep_attrs=None, **kwargs): @@ -120,23 +120,36 @@ def method(self, keep_attrs=None, **kwargs): keep_attrs = self._get_keep_attrs(keep_attrs) return self._numpy_or_bottleneck_reduce( - array_agg_func, bottleneck_move_func, keep_attrs=keep_attrs, **kwargs + array_agg_func, + bottleneck_move_func, + rolling_agg_func, + keep_attrs=keep_attrs, + fillna=fillna, + **kwargs, ) method.__name__ = name method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) return method - argmax = _reduce_method("argmax") - argmin = _reduce_method("argmin") - max = _reduce_method("max") - min = _reduce_method("min") - mean = _reduce_method("mean") - prod = _reduce_method("prod") - sum = _reduce_method("sum") - std = _reduce_method("std") - var = _reduce_method("var") - median = _reduce_method("median") + def _mean(self, keep_attrs, **kwargs): + result = self.sum(keep_attrs=False, **kwargs) / self.count(keep_attrs=False) + if keep_attrs: + result.attrs = self.obj.attrs + return result + + _mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean") + + argmax = _reduce_method("argmax", dtypes.NINF) + argmin = _reduce_method("argmin", dtypes.INF) + max = _reduce_method("max", dtypes.NINF) + min = _reduce_method("min", dtypes.INF) + prod = _reduce_method("prod", 1) + sum = _reduce_method("sum", 0) + mean = _reduce_method("mean", None, _mean) + std = _reduce_method("std", None) + var = _reduce_method("var", None) + median = _reduce_method("median", None) def count(self, keep_attrs=None): keep_attrs = self._get_keep_attrs(keep_attrs) @@ -152,11 +165,10 @@ def _mapping_to_list( if utils.is_dict_like(arg): if allow_default: return [arg.get(d, default) for d in self.dim] - else: - for d in self.dim: - if d not in arg: - raise KeyError(f"argument has no key {d}.") - return [arg[d] for d in self.dim] + for d in self.dim: + if d not in arg: + raise KeyError(f"argument has no key {d}.") + return [arg[d] for d in self.dim] elif allow_allsame: # for single argument return [arg] * len(self.dim) elif len(self.dim) == 1: @@ -167,15 +179,8 @@ def _mapping_to_list( ) def _get_keep_attrs(self, keep_attrs): - if keep_attrs is None: - # TODO: uncomment the next line and remove the others after the deprecation - # keep_attrs = _get_keep_attrs(default=True) - - if self.keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - else: - keep_attrs = self.keep_attrs + keep_attrs = _get_keep_attrs(default=True) return keep_attrs @@ -183,7 +188,7 @@ def _get_keep_attrs(self, keep_attrs): class DataArrayRolling(Rolling): __slots__ = ("window_labels",) - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -214,9 +219,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None xarray.Dataset.rolling xarray.Dataset.groupby """ - super().__init__( - obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs - ) + super().__init__(obj, windows, min_periods=min_periods, center=center) # TODO legacy attribute self.window_labels = self.obj[self.dim[0]] @@ -301,6 +304,24 @@ def construct( """ + return self._construct( + self.obj, + window_dim=window_dim, + stride=stride, + fill_value=fill_value, + keep_attrs=keep_attrs, + **window_dim_kwargs, + ) + + def _construct( + self, + obj, + window_dim=None, + stride=1, + fill_value=dtypes.NA, + keep_attrs=None, + **window_dim_kwargs, + ): from .dataarray import DataArray keep_attrs = self._get_keep_attrs(keep_attrs) @@ -317,18 +338,18 @@ def construct( ) stride = self._mapping_to_list(stride, default=1) - window = self.obj.variable.rolling_window( + window = obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value ) - attrs = self.obj.attrs if keep_attrs else {} + attrs = obj.attrs if keep_attrs else {} result = DataArray( window, - dims=self.obj.dims + tuple(window_dim), - coords=self.obj.coords, + dims=obj.dims + tuple(window_dim), + coords=obj.coords, attrs=attrs, - name=self.obj.name, + name=obj.name, ) return result.isel( **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} @@ -393,7 +414,17 @@ def reduce(self, func, keep_attrs=None, **kwargs): d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}") for d in self.dim } - windows = self.construct(rolling_dim, keep_attrs=keep_attrs) + + # save memory with reductions GH4325 + fillna = kwargs.pop("fillna", dtypes.NA) + if fillna is not dtypes.NA: + obj = self.obj.fillna(fillna) + else: + obj = self.obj + windows = self._construct( + obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna + ) + result = windows.reduce( func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs ) @@ -454,9 +485,6 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): if is_duck_dask_array(padded.data): raise AssertionError("should not be reachable") - values = dask_rolling_wrapper( - func, padded.data, window=self.window[0], min_count=min_count, axis=axis - ) else: values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis @@ -470,7 +498,13 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name) def _numpy_or_bottleneck_reduce( - self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs + self, + array_agg_func, + bottleneck_move_func, + rolling_agg_func, + keep_attrs, + fillna, + **kwargs, ): if "dim" in kwargs: warnings.warn( @@ -483,7 +517,8 @@ def _numpy_or_bottleneck_reduce( del kwargs["dim"] if ( - bottleneck_move_func is not None + OPTIONS["use_bottleneck"] + and bottleneck_move_func is not None and not is_duck_dask_array(self.obj.data) and len(self.dim) == 1 ): @@ -493,14 +528,23 @@ def _numpy_or_bottleneck_reduce( return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs ) - else: - return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) + if rolling_agg_func: + return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs)) + if fillna is not None: + if fillna is dtypes.INF: + fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True) + elif fillna is dtypes.NINF: + fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True) + kwargs.setdefault("skipna", False) + kwargs.setdefault("fillna", fillna) + + return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) class DatasetRolling(Rolling): __slots__ = ("rollings",) - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -531,7 +575,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None xarray.Dataset.groupby xarray.DataArray.groupby """ - super().__init__(obj, windows, min_periods, center, keep_attrs) + super().__init__(obj, windows, min_periods, center) if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) # Keep each Rolling object as a dictionary @@ -544,7 +588,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None dims.append(d) center[d] = self.center[i] - if len(dims) > 0: + if dims: w = {d: windows[d] for d in dims} self.rollings[key] = DataArrayRolling(da, w, min_periods, center) @@ -600,13 +644,19 @@ def _counts(self, keep_attrs): ) def _numpy_or_bottleneck_reduce( - self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs + self, + array_agg_func, + bottleneck_move_func, + rolling_agg_func, + keep_attrs, + **kwargs, ): return self._dataset_implementation( functools.partial( DataArrayRolling._numpy_or_bottleneck_reduce, array_agg_func=array_agg_func, bottleneck_move_func=bottleneck_move_func, + rolling_agg_func=rolling_agg_func, ), keep_attrs=keep_attrs, **kwargs, @@ -661,7 +711,7 @@ def construct( for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on self.dim dims = [d for d in self.dim if d in da.dims] - if len(dims) > 0: + if dims: wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} @@ -685,7 +735,7 @@ def construct( ) -class Coarsen: +class Coarsen(CoarsenArithmetic): """A object that implements the coarsen. See Also @@ -701,11 +751,10 @@ class Coarsen: "windows", "side", "trim_excess", - "keep_attrs", ) _attributes = ("windows", "side", "trim_excess") - def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): + def __init__(self, obj, windows, boundary, side, coord_func): """ Moving window object. @@ -721,7 +770,7 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): multiple of window size. If 'trim', the excess indexes are trimed. If 'pad', NA will be padded. side : 'left' or 'right' or mapping from dimension to 'left' or 'right' - coord_func: mapping from coordinate name to func. + coord_func : mapping from coordinate name to func. Returns ------- @@ -731,7 +780,6 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): self.windows = windows self.side = side self.boundary = boundary - self.keep_attrs = keep_attrs absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims] if absent_dims: @@ -745,6 +793,12 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): coord_func[c] = duck_array_ops.mean self.coord_func = coord_func + def _get_keep_attrs(self, keep_attrs): + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return keep_attrs + def __repr__(self): """provide a nice str repr of our coarsen object""" @@ -757,6 +811,109 @@ def __repr__(self): klass=self.__class__.__name__, attrs=",".join(attrs) ) + def construct( + self, + window_dim=None, + keep_attrs=None, + **window_dim_kwargs, + ): + """ + Convert this Coarsen object to a DataArray or Dataset, + where the coarsening dimension is split or reshaped to two + new dimensions. + + Parameters + ---------- + window_dim: mapping + A mapping from existing dimension name to new dimension names. + The size of the second dimension will be the length of the + coarsening window. + keep_attrs: bool, optional + Preserve attributes if True + **window_dim_kwargs : {dim: new_name, ...} + The keyword arguments form of ``window_dim``. + + Returns + ------- + Dataset or DataArray with reshaped dimensions + + Examples + -------- + >>> da = xr.DataArray(np.arange(24), dims="time") + >>> da.coarsen(time=12).construct(time=("year", "month")) + + array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]]) + Dimensions without coordinates: year, month + + See Also + -------- + DataArrayRolling.construct + DatasetRolling.construct + """ + + from .dataarray import DataArray + from .dataset import Dataset + + window_dim = either_dict_or_kwargs( + window_dim, window_dim_kwargs, "Coarsen.construct" + ) + if not window_dim: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + + bad_new_dims = tuple( + win + for win, dims in window_dim.items() + if len(dims) != 2 or isinstance(dims, str) + ) + if bad_new_dims: + raise ValueError( + f"Please provide exactly two dimension names for the following coarsening dimensions: {bad_new_dims}" + ) + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + missing_dims = set(window_dim) - set(self.windows) + if missing_dims: + raise ValueError( + f"'window_dim' must contain entries for all dimensions to coarsen. Missing {missing_dims}" + ) + extra_windows = set(self.windows) - set(window_dim) + if extra_windows: + raise ValueError( + f"'window_dim' includes dimensions that will not be coarsened: {extra_windows}" + ) + + reshaped = Dataset() + if isinstance(self.obj, DataArray): + obj = self.obj._to_temp_dataset() + else: + obj = self.obj + + reshaped.attrs = obj.attrs if keep_attrs else {} + + for key, var in obj.variables.items(): + reshaped_dims = tuple( + itertools.chain(*[window_dim.get(dim, [dim]) for dim in list(var.dims)]) + ) + if reshaped_dims != var.dims: + windows = {w: self.windows[w] for w in window_dim if w in var.dims} + reshaped_var, _ = var.coarsen_reshape(windows, self.boundary, self.side) + attrs = var.attrs if keep_attrs else {} + reshaped[key] = (reshaped_dims, reshaped_var, attrs) + else: + reshaped[key] = var + + should_be_coords = set(window_dim) & set(self.obj.coords) + result = reshaped.set_coords(should_be_coords) + if isinstance(self.obj, DataArray): + return self.obj._from_temp_dataset(result) + else: + return result + class DataArrayCoarsen(Coarsen): __slots__ = () @@ -764,7 +921,9 @@ class DataArrayCoarsen(Coarsen): _reduce_extra_args_docstring = """""" @classmethod - def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool): + def _reduce_method( + cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False + ): """ Return a wrapped function for injecting reduction methods. see ops.inject_reduce_methods @@ -773,11 +932,13 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool if include_skipna: kwargs["skipna"] = None - def wrapped_func(self, **kwargs): + def wrapped_func(self, keep_attrs: bool = None, **kwargs): from .dataarray import DataArray + keep_attrs = self._get_keep_attrs(keep_attrs) + reduced = self.obj.variable.coarsen( - self.windows, func, self.boundary, self.side, self.keep_attrs, **kwargs + self.windows, func, self.boundary, self.side, keep_attrs, **kwargs ) coords = {} for c, v in self.obj.coords.items(): @@ -790,15 +951,53 @@ def wrapped_func(self, **kwargs): self.coord_func[c], self.boundary, self.side, - self.keep_attrs, + keep_attrs, **kwargs, ) else: coords[c] = v - return DataArray(reduced, dims=self.obj.dims, coords=coords) + return DataArray( + reduced, dims=self.obj.dims, coords=coords, name=self.obj.name + ) return wrapped_func + def reduce(self, func: Callable, keep_attrs: bool = None, **kwargs): + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, axis, **kwargs)` + to return the result of collapsing an np.ndarray over the coarsening + dimensions. It must be possible to provide the `axis` argument + with a tuple of integers. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. + + Returns + ------- + reduced : DataArray + Array with summarized data. + + Examples + -------- + >>> da = xr.DataArray(np.arange(8).reshape(2, 4), dims=("a", "b")) + >>> coarsen = da.coarsen(b=2) + >>> coarsen.reduce(np.sum) + + array([[ 1, 5], + [ 9, 13]]) + Dimensions without coordinates: a, b + """ + wrapped_func = self._reduce_method(func) + return wrapped_func(self, keep_attrs=keep_attrs, **kwargs) + class DatasetCoarsen(Coarsen): __slots__ = () @@ -806,7 +1005,9 @@ class DatasetCoarsen(Coarsen): _reduce_extra_args_docstring = """""" @classmethod - def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool): + def _reduce_method( + cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False + ): """ Return a wrapped function for injecting reduction methods. see ops.inject_reduce_methods @@ -815,10 +1016,12 @@ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool if include_skipna: kwargs["skipna"] = None - def wrapped_func(self, **kwargs): + def wrapped_func(self, keep_attrs: bool = None, **kwargs): from .dataset import Dataset - if self.keep_attrs: + keep_attrs = self._get_keep_attrs(keep_attrs) + + if keep_attrs: attrs = self.obj.attrs else: attrs = {} @@ -826,25 +1029,53 @@ def wrapped_func(self, **kwargs): reduced = {} for key, da in self.obj.data_vars.items(): reduced[key] = da.variable.coarsen( - self.windows, func, self.boundary, self.side, **kwargs + self.windows, + func, + self.boundary, + self.side, + keep_attrs=keep_attrs, + **kwargs, ) coords = {} for c, v in self.obj.coords.items(): - if any(d in self.windows for d in v.dims): - coords[c] = v.variable.coarsen( - self.windows, - self.coord_func[c], - self.boundary, - self.side, - **kwargs, - ) - else: - coords[c] = v.variable + # variable.coarsen returns variables not containing the window dims + # unchanged (maybe removes attrs) + coords[c] = v.variable.coarsen( + self.windows, + self.coord_func[c], + self.boundary, + self.side, + keep_attrs=keep_attrs, + **kwargs, + ) + return Dataset(reduced, coords=coords, attrs=attrs) return wrapped_func + def reduce(self, func: Callable, keep_attrs=None, **kwargs): + """Reduce the items in this group by applying `func` along some + dimension(s). + + Parameters + ---------- + func : callable + Function which can be called in the form `func(x, axis, **kwargs)` + to return the result of collapsing an np.ndarray over the coarsening + dimensions. It must be possible to provide the `axis` argument with + a tuple of integers. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **kwargs : dict + Additional keyword arguments passed on to `func`. -inject_reduce_methods(DataArrayCoarsen) -inject_reduce_methods(DatasetCoarsen) + Returns + ------- + reduced : Dataset + Arrays with summarized data. + """ + wrapped_func = self._reduce_method(func) + return wrapped_func(self, keep_attrs=keep_attrs, **kwargs) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 0ae85a870e8..e0fe57a9fb0 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar +from distutils.version import LooseVersion +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union import numpy as np @@ -26,12 +27,25 @@ def move_exp_nanmean(array, *, axis, alpha): raise TypeError("rolling_exp is not currently support for dask-like arrays") import numbagg + # No longer needed in numbag > 0.2.0; remove in time if axis == (): return array.astype(np.float64) else: return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha) +def move_exp_nansum(array, *, axis, alpha): + if is_duck_dask_array(array): + raise TypeError("rolling_exp is not currently supported for dask-like arrays") + import numbagg + + # numbagg <= 0.2.0 did not have a __version__ attribute + if LooseVersion(getattr(numbagg, "__version__", "0.1.0")) < LooseVersion("0.2.0"): + raise ValueError("`rolling_exp(...).sum() requires numbagg>=0.2.1.") + + return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha) + + def _get_center_of_mass(comass, span, halflife, alpha): """ Vendored from pandas.core.window.common._get_center_of_mass @@ -74,7 +88,7 @@ class RollingExp(Generic[T_DSorDA]): ---------- obj : Dataset or DataArray Object to window. - windows : mapping of hashable to int + windows : mapping of hashable to int (or float for alpha type) A mapping from the name of the dimension to create the rolling exponential window along (e.g. `time`) to the size of the moving window. window_type : {"span", "com", "halflife", "alpha"}, default: "span" @@ -90,7 +104,7 @@ class RollingExp(Generic[T_DSorDA]): def __init__( self, obj: T_DSorDA, - windows: Mapping[Hashable, int], + windows: Mapping[Hashable, Union[int, float]], window_type: str = "span", ): self.obj: T_DSorDA = obj @@ -98,9 +112,9 @@ def __init__( self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA: + def mean(self, keep_attrs: bool = None) -> T_DSorDA: """ - Exponentially weighted moving average + Exponentially weighted moving average. Parameters ---------- @@ -124,3 +138,30 @@ def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA: return self.obj.reduce( move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs ) + + def sum(self, keep_attrs: bool = None) -> T_DSorDA: + """ + Exponentially weighted moving sum. + + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + + Examples + -------- + >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") + >>> da.rolling_exp(x=2, window_type="span").sum() + + array([1. , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ]) + Dimensions without coordinates: x + """ + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return self.obj.reduce( + move_exp_nansum, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs + ) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ced688f32dd..a139d2ef10a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -4,11 +4,13 @@ import functools import io import itertools -import os.path +import os import re +import sys import warnings from enum import Enum from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -31,8 +33,6 @@ import numpy as np import pandas as pd -from . import dtypes - K = TypeVar("K") V = TypeVar("V") T = TypeVar("T") @@ -83,6 +83,7 @@ def maybe_coerce_to_str(index, original_coords): pd.Index uses object-dtype to store str - try to avoid this for coords """ + from . import dtypes try: result_type = dtypes.result_type(*original_coords) @@ -108,6 +109,8 @@ def safe_cast_to_index(array: Any) -> pd.Index: index = array elif hasattr(array, "to_index"): index = array.to_index() + elif hasattr(array, "to_pandas_index"): + index = array.to_pandas_index() else: kwargs = {} if hasattr(array, "dtype") and array.dtype.kind == "O": @@ -219,7 +222,7 @@ def update_safety_check( if k in first_dict and not compat(v, first_dict[k]): raise ValueError( "unsafe to merge dictionaries without " - "overriding values; conflicting key %r" % k + f"overriding values; conflicting key {k!r}" ) @@ -255,7 +258,7 @@ def is_full_slice(value: Any) -> bool: def is_list_like(value: Any) -> bool: - return isinstance(value, list) or isinstance(value, tuple) + return isinstance(value, (list, tuple)) def is_duck_array(value: Any) -> bool: @@ -275,28 +278,21 @@ def either_dict_or_kwargs( kw_kwargs: Mapping[str, T], func_name: str, ) -> Mapping[Hashable, T]: - if pos_kwargs is not None: - if not is_dict_like(pos_kwargs): - raise ValueError( - "the first argument to .%s must be a dictionary" % func_name - ) - if kw_kwargs: - raise ValueError( - "cannot specify both keyword and positional " - "arguments to .%s" % func_name - ) - return pos_kwargs - else: + if pos_kwargs is None: # Need an explicit cast to appease mypy due to invariance; see # https://github.com/python/mypy/issues/6228 return cast(Mapping[Hashable, T], kw_kwargs) + if not is_dict_like(pos_kwargs): + raise ValueError(f"the first argument to .{func_name} must be a dictionary") + if kw_kwargs: + raise ValueError( + f"cannot specify both keyword and positional arguments to .{func_name}" + ) + return pos_kwargs -def is_scalar(value: Any, include_0d: bool = True) -> bool: - """Whether to treat a value as a scalar. - Any non-iterable, string, or 0-D array - """ +def _is_scalar(value, include_0d): from .variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES if include_0d: @@ -311,6 +307,37 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool: ) +# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without +# requiring typing_extensions as a required dependency to _run_ the code (it is required +# to type-check). +try: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard +except ImportError: + if TYPE_CHECKING: + raise + else: + + def is_scalar(value: Any, include_0d: bool = True) -> bool: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + +else: + + def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + def is_valid_numpy_dtype(dtype: Any) -> bool: try: np.dtype(dtype) @@ -359,10 +386,7 @@ def dict_equiv( for k in first: if k not in second or not compat(first[k], second[k]): return False - for k in second: - if k not in first: - return False - return True + return all(k in first for k in second) def compat_dict_intersection( @@ -482,40 +506,6 @@ def __len__(self) -> int: return len(self._keys) -class SortedKeysDict(MutableMapping[K, V]): - """An wrapper for dictionary-like objects that always iterates over its - items in sorted order by key but is otherwise equivalent to the underlying - mapping. - """ - - __slots__ = ("mapping",) - - def __init__(self, mapping: MutableMapping[K, V] = None): - self.mapping = {} if mapping is None else mapping - - def __getitem__(self, key: K) -> V: - return self.mapping[key] - - def __setitem__(self, key: K, value: V) -> None: - self.mapping[key] = value - - def __delitem__(self, key: K) -> None: - del self.mapping[key] - - def __iter__(self) -> Iterator[K]: - # see #4571 for the reason of the type ignore - return iter(sorted(self.mapping)) # type: ignore - - def __len__(self) -> int: - return len(self.mapping) - - def __contains__(self, key: object) -> bool: - return key in self.mapping - - def __repr__(self) -> str: - return "{}({!r})".format(type(self).__name__, self.mapping) - - class OrderedSet(MutableSet[T]): """A simple ordered set. @@ -645,10 +635,15 @@ def close_on_error(f): def is_remote_uri(path: str) -> bool: - return bool(re.search(r"^https?\://", path)) + """Finds URLs of the form protocol:// or protocol:: + + This also matches for http[s]://, which were the only remote URLs + supported in <=v0.16.2. + """ + return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path)) -def read_magic_number(filename_or_obj, count=8): +def read_magic_number_from_file(filename_or_obj, count=8) -> bytes: # check byte header to determine file type if isinstance(filename_or_obj, bytes): magic_number = filename_or_obj[:count] @@ -659,16 +654,34 @@ def read_magic_number(filename_or_obj, count=8): "file-like object read/write pointer not at the start of the file, " "please close and reopen, or use a context manager" ) - magic_number = filename_or_obj.read(count) + magic_number = filename_or_obj.read(count) # type: ignore filename_or_obj.seek(0) else: raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}") return magic_number -def is_grib_path(path: str) -> bool: - _, ext = os.path.splitext(path) - return ext in [".grib", ".grb", ".grib2", ".grb2"] +def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]: + if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"): + path = os.fspath(pathlike) + try: + with open(path, "rb") as f: + return read_magic_number_from_file(f, count) + except (FileNotFoundError, TypeError): + pass + return None + + +def try_read_magic_number_from_file_or_path( + filename_or_obj, count=8 +) -> Optional[bytes]: + magic_number = try_read_magic_number_from_path(filename_or_obj, count) + if magic_number is None: + try: + magic_number = read_magic_number_from_file(filename_or_obj, count) + except TypeError: + pass + return magic_number def is_uniform_spaced(arr, **kwargs) -> bool: @@ -695,10 +708,6 @@ def hashable(v: Any) -> bool: return True -def not_implemented(*args, **kwargs): - return NotImplemented - - def decode_numpy_dict_values(attrs: Mapping[K, V]) -> Dict[K, V]: """Convert attribute values from numpy objects to native Python objects, for use in to_dict @@ -735,7 +744,7 @@ def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]): def _raise_if_hidden(self, key: K) -> None: if key in self._hidden_keys: - raise KeyError("Key `%r` is hidden." % key) + raise KeyError(f"Key `{key!r}` is hidden.") # The next five methods are requirements of the ABC. def __setitem__(self, key: K, value: V) -> None: @@ -868,7 +877,7 @@ def drop_missing_dims( """ if missing_dims == "raise": - supplied_dims_set = set(val for val in supplied_dims if val is not ...) + supplied_dims_set = {val for val in supplied_dims if val is not ...} invalid = supplied_dims_set - set(dims) if invalid: raise ValueError( @@ -920,3 +929,11 @@ class Default(Enum): _default = Default.token + + +def iterate_nested(nested_list): + for item in nested_list: + if isinstance(item, list): + yield from iterate_nested(item) + else: + yield item diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 64c1895da59..6b971389de7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,11 +1,9 @@ import copy -import functools import itertools import numbers import warnings from collections import defaultdict from datetime import timedelta -from distutils.version import LooseVersion from typing import ( Any, Dict, @@ -24,23 +22,28 @@ import xarray as xr # only for Dataset and DataArray -from . import arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils +from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils +from .arithmetic import VariableArithmetic +from .common import AbstractArray +from .indexes import PandasIndex, PandasMultiIndex from .indexing import ( BasicIndexer, OuterIndexer, - PandasIndexAdapter, + PandasIndexingAdapter, VectorizedIndexer, as_indexable, ) -from .npcompat import IS_NEP18_ACTIVE -from .options import _get_keep_attrs +from .options import OPTIONS, _get_keep_attrs from .pycompat import ( + DuckArrayModule, cupy_array_type, dask_array_type, integer_types, is_duck_dask_array, + sparse_array_type, ) from .utils import ( + NdimSizeLenMixin, OrderedSet, _default, decode_numpy_dict_values, @@ -61,7 +64,7 @@ + cupy_array_type ) # https://github.com/python/mypy/issues/224 -BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore +BASIC_INDEXING_TYPES = integer_types + (slice,) VariableType = TypeVar("VariableType", bound="Variable") """Type annotation to be used when methods of Variable return self or a copy of self. @@ -120,6 +123,11 @@ def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": if isinstance(obj, Variable): obj = obj.copy(deep=False) elif isinstance(obj, tuple): + if isinstance(obj[1], DataArray): + raise TypeError( + "Using a DataArray object to construct a variable is" + " ambiguous, please extract the data using the .data property." + ) try: obj = Variable(*obj) except (TypeError, ValueError) as error: @@ -139,25 +147,24 @@ def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": data = as_compatible_data(obj) if data.ndim != 1: raise MissingDimensionsError( - "cannot set variable %r with %r-dimensional data " + f"cannot set variable {name!r} with {data.ndim!r}-dimensional data " "without explicit dimension names. Pass a tuple of " - "(dims, data) instead." % (name, data.ndim) + "(dims, data) instead." ) obj = Variable(name, data, fastpath=True) else: raise TypeError( "unable to convert object into a variable without an " - "explicit list of dimensions: %r" % obj + f"explicit list of dimensions: {obj!r}" ) if name is not None and name in obj.dims: # convert the Variable into an Index if obj.ndim != 1: raise MissingDimensionsError( - "%r has more than 1-dimension and the same name as one of its " - "dimensions %r. xarray disallows such variables because they " - "conflict with the coordinates used to label " - "dimensions." % (name, obj.dims) + f"{name!r} has more than 1-dimension and the same name as one of its " + f"dimensions {obj.dims!r}. xarray disallows such variables because they " + "conflict with the coordinates used to label dimensions." ) obj = obj.to_index_variable() @@ -169,11 +176,11 @@ def _maybe_wrap_data(data): Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure they can be indexed properly. - NumpyArrayAdapter, PandasIndexAdapter and LazilyOuterIndexedArray should + NumpyArrayAdapter, PandasIndexingAdapter and LazilyIndexedArray should all pass through unmodified. """ if isinstance(data, pd.Index): - return PandasIndexAdapter(data) + return PandasIndexingAdapter(data) return data @@ -218,7 +225,8 @@ def as_compatible_data(data, fastpath=False): data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - data = getattr(data, "values", data) + if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)): + data = data.values if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) @@ -229,30 +237,14 @@ def as_compatible_data(data, fastpath=False): else: data = np.asarray(data) - if not isinstance(data, np.ndarray): - if hasattr(data, "__array_function__"): - if IS_NEP18_ACTIVE: - return data - else: - raise TypeError( - "Got an NumPy-like array type providing the " - "__array_function__ protocol but NEP18 is not enabled. " - "Check that numpy >= v1.16 and that the environment " - 'variable "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION" is set to ' - '"1"' - ) + if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"): + return data # validate whether the data is valid data types. data = np.asarray(data) - if isinstance(data, np.ndarray): - if data.dtype.kind == "O": - data = _possibly_convert_objects(data) - elif data.dtype.kind == "M": - data = _possibly_convert_objects(data) - elif data.dtype.kind == "m": - data = _possibly_convert_objects(data) - + if isinstance(data, np.ndarray) and data.dtype.kind in "OMm": + data = _possibly_convert_objects(data) return _maybe_wrap_data(data) @@ -270,10 +262,7 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - if isinstance(data, cupy_array_type): - data = data.get() - else: - data = np.asarray(data) + data = np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") @@ -282,9 +271,7 @@ def _as_array_or_item(data): return data -class Variable( - common.AbstractArray, arithmetic.SupportsArithmetic, utils.NdimSizeLenMixin -): +class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -350,7 +337,9 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, np.number, PandasIndexAdapter)) or ( + return isinstance( + self._data, (np.ndarray, np.number, PandasIndexingAdapter) + ) or ( isinstance(self._data, indexing.MemoryCachedArray) and isinstance(self._data.array, indexing.NumpyIndexingAdapter) ) @@ -403,7 +392,6 @@ def astype( * 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed. * 'unsafe' means any data conversions may be done. - subok : bool, optional If True, then sub-classes will be passed-through, otherwise the returned array will be forced to be a base-class array. @@ -428,7 +416,7 @@ def astype( Make sure to only supply these arguments if the underlying array class supports them. - See also + See Also -------- numpy.ndarray.astype dask.array.Array.astype @@ -521,22 +509,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): array_func, array_args = self._data.__dask_postcompute__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args def __dask_postpersist__(self): array_func, array_args = self._data.__dask_postpersist__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args - @staticmethod - def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - data = array_func(results, *array_args) - return Variable(dims, data, attrs=attrs, encoding=encoding) + def _dask_finalize(self, results, array_func, *args, **kwargs): + data = array_func(results, *args, **kwargs) + return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @property def values(self): @@ -563,6 +544,18 @@ def to_index_variable(self): to_coord = utils.alias(to_index_variable, "to_coord") + def _to_xindex(self): + # temporary function used internally as a replacement of to_index() + # returns an xarray Index instance instead of a pd.Index instance + index_var = self.to_index_variable() + index = index_var.to_index() + dim = index_var.dims[0] + + if isinstance(index, pd.MultiIndex): + return PandasMultiIndex(index, dim) + else: + return PandasIndex(index, dim) + def to_index(self): """Convert this variable to a pandas.Index""" return self.to_index_variable().to_index() @@ -591,8 +584,8 @@ def _parse_dimensions(self, dims): dims = tuple(dims) if len(dims) != self.ndim: raise ValueError( - "dimensions %s must have the same length as the " - "number of data dimensions, ndim=%s" % (dims, self.ndim) + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" ) return dims @@ -606,8 +599,8 @@ def _broadcast_indexes(self, key): """Prepare an indexing key for an indexing operation. Parameters - ----------- - key: int, slice, array-like, dict or tuple of integer, slice and array-like + ---------- + key : int, slice, array-like, dict or tuple of integer, slice and array-like Any valid input for indexing. Returns @@ -667,11 +660,9 @@ def _broadcast_indexes_basic(self, key): return dims, BasicIndexer(key), None def _validate_indexers(self, key): - """ Make sanity checks """ + """Make sanity checks""" for dim, k in zip(self.dims, key): - if isinstance(k, BASIC_INDEXING_TYPES): - pass - else: + if not isinstance(k, BASIC_INDEXING_TYPES): if not isinstance(k, Variable): k = np.asarray(k) if k.ndim > 1: @@ -722,7 +713,7 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None def _nonzero(self): - """ Equivalent numpy's nonzero but returns a tuple of Varibles. """ + """Equivalent numpy's nonzero but returns a tuple of Varibles.""" # TODO we should replace dask's native nonzero # after https://github.com/dask/dask/issues/1076 is implemented. nonzeros = np.nonzero(self.data) @@ -800,12 +791,12 @@ def __getitem__(self: VariableType, key) -> VariableType: dims, indexer, new_order = self._broadcast_indexes(key) data = as_indexable(self._data)[indexer] if new_order: - data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) + data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: """Used by IndexVariable to return IndexVariable objects when possible.""" - return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) + return self._replace(dims=dims, data=data) def _getitem_with_mask(self, key, fill_value=dtypes.NA): """Index this Variable with -1 remapped to fill_value.""" @@ -859,9 +850,8 @@ def __setitem__(self, key, value): value = as_compatible_data(value) if value.ndim > len(dims): raise ValueError( - "shape mismatch: value array of shape %s could not be " - "broadcast to indexing result with %s dimensions" - % (value.shape, len(dims)) + f"shape mismatch: value array of shape {value.shape} could not be " + f"broadcast to indexing result with {len(dims)} dimensions" ) if value.ndim == 0: value = Variable((), value) @@ -873,7 +863,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) + value = np.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexable[index_tuple] = value @@ -929,7 +919,6 @@ def copy(self, deep=True, data=None): Examples -------- - Shallow copy versus deep copy >>> var = xr.Variable(data=[1, 2, 3], dims="x") @@ -985,8 +974,12 @@ def copy(self, deep=True, data=None): return self._replace(data=data) def _replace( - self, dims=_default, data=_default, attrs=_default, encoding=_default - ) -> "Variable": + self: VariableType, + dims=_default, + data=_default, + attrs=_default, + encoding=_default, + ) -> VariableType: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -1007,7 +1000,7 @@ def __deepcopy__(self, memo=None): # mutable objects should not be hashable # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore + __hash__ = None # type: ignore[assignment] @property def chunks(self): @@ -1045,7 +1038,6 @@ def chunk(self, chunks={}, name=None, lock=False): ------- chunked : xarray.Variable """ - import dask import dask.array as da if chunks is None: @@ -1075,12 +1067,10 @@ def chunk(self, chunks={}, name=None, lock=False): data = indexing.ImplicitToExplicitIndexingAdapter( data, indexing.OuterIndexer ) - if LooseVersion(dask.__version__) < "2.0.0": - kwargs = {} - else: - # All of our lazily loaded backend array classes should use NumPy - # array operations. - kwargs = {"meta": np.ndarray} + + # All of our lazily loaded backend array classes should use NumPy + # array operations. + kwargs = {"meta": np.ndarray} else: kwargs = {} @@ -1089,7 +1079,31 @@ def chunk(self, chunks={}, name=None, lock=False): data = da.from_array(data, chunks, name=name, lock=lock, **kwargs) - return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) + return self._replace(data=data) + + def to_numpy(self) -> np.ndarray: + """Coerces wrapped data to numpy and returns a numpy.ndarray""" + # TODO an entrypoint so array libraries can choose coercion method? + data = self.data + + # TODO first attempt to call .to_numpy() once some libraries implement it + if isinstance(data, dask_array_type): + data = data.compute() + if isinstance(data, cupy_array_type): + data = data.get() + # pint has to be imported dynamically as pint imports xarray + pint_array_type = DuckArrayModule("pint").type + if isinstance(data, pint_array_type): + data = data.magnitude + if isinstance(data, sparse_array_type): + data = data.todense() + data = np.asarray(data) + + return data + + def as_numpy(self: VariableType) -> VariableType: + """Coerces wrapped data into a numpy array, returning a Variable.""" + return self._replace(data=self.to_numpy()) def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): """ @@ -1213,7 +1227,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): # TODO: remove this once dask.array automatically aligns chunks data = data.rechunk(self.data.chunks) - return type(self)(self.dims, data, self._attrs, fastpath=True) + return self._replace(data=data) def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): """ @@ -1225,7 +1239,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. - fill_value: scalar, optional + fill_value : scalar, optional Value to use for newly missing values **shifts_kwargs The keyword arguments form of ``shifts``. @@ -1325,7 +1339,7 @@ def pad( # workaround for bug in Dask's default value of stat_length https://github.com/dask/dask/issues/5303 if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]: - stat_length = [(n, n) for n in self.data.shape] # type: ignore + stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment] # change integer values to a tuple of two of those values and change pad_width to index for k, v in pad_width.items(): @@ -1342,7 +1356,7 @@ def pad( if end_values is not None: pad_option_kwargs["end_values"] = end_values if reflect_type is not None: - pad_option_kwargs["reflect_type"] = reflect_type # type: ignore + pad_option_kwargs["reflect_type"] = reflect_type # type: ignore[assignment] array = duck_array_ops.pad( self.data.astype(dtype, copy=False), @@ -1372,7 +1386,7 @@ def _roll_one_dim(self, dim, count): # TODO: remove this once dask.array automatically aligns chunks data = data.rechunk(self.data.chunks) - return type(self)(self.dims, data, self._attrs, fastpath=True) + return self._replace(data=data) def roll(self, shifts=None, **shifts_kwargs): """ @@ -1400,7 +1414,11 @@ def roll(self, shifts=None, **shifts_kwargs): result = result._roll_one_dim(dim, count) return result - def transpose(self, *dims) -> "Variable": + def transpose( + self, + *dims, + missing_dims: str = "raise", + ) -> "Variable": """Return a new Variable object with transposed dimensions. Parameters @@ -1408,6 +1426,12 @@ def transpose(self, *dims) -> "Variable": *dims : str, optional By default, reverse the dimensions. Otherwise, reorder the dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Variable: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions Returns ------- @@ -1426,15 +1450,17 @@ def transpose(self, *dims) -> "Variable": """ if len(dims) == 0: dims = self.dims[::-1] - dims = tuple(infix_dims(dims, self.dims)) - axes = self.get_axis_num(dims) + else: + dims = tuple(infix_dims(dims, self.dims, missing_dims)) + if len(dims) < 2 or dims == self.dims: # no need to transpose if only one dimension # or dims are in same order return self.copy(deep=False) + axes = self.get_axis_num(dims) data = as_indexable(self._data).transpose(axes) - return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) + return self._replace(dims=dims, data=data) @property def T(self) -> "Variable": @@ -1466,8 +1492,8 @@ def set_dims(self, dims, shape=None): missing_dims = set(self.dims) - set(dims) if missing_dims: raise ValueError( - "new dimensions %r must be a superset of " - "existing dimensions %r" % (dims, self.dims) + f"new dimensions {dims!r} must be a superset of " + f"existing dimensions {self.dims!r}" ) self_dims = set(self.dims) @@ -1491,7 +1517,7 @@ def set_dims(self, dims, shape=None): def _stack_once(self, dims: List[Hashable], new_dim: Hashable): if not set(dims) <= set(self.dims): - raise ValueError("invalid existing dimensions: %s" % dims) + raise ValueError(f"invalid existing dimensions: {dims}") if new_dim in self.dims: raise ValueError( @@ -1535,7 +1561,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): stacked : Variable Variable with the same attributes but stacked data. - See also + See Also -------- Variable.unstack """ @@ -1558,7 +1584,7 @@ def _unstack_once_full( new_dim_sizes = tuple(dims.values()) if old_dim not in self.dims: - raise ValueError("invalid existing dimension: %s" % old_dim) + raise ValueError(f"invalid existing dimension: {old_dim}") if set(new_dim_names).intersection(self.dims): raise ValueError( @@ -1601,7 +1627,7 @@ def _unstack_once( # Potentially we could replace `len(other_dims)` with just `-1` other_dims = [d for d in self.dims if d != dim] - new_shape = list(reordered.shape[: len(other_dims)]) + new_dim_sizes + new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) new_dims = reordered.dims[: len(other_dims)] + new_dim_names if fill_value is dtypes.NA: @@ -1614,7 +1640,6 @@ def _unstack_once( else: dtype = self.dtype - # Currently fails on sparse due to https://github.com/pydata/sparse/issues/422 data = np.full_like( self.data, fill_value=fill_value, @@ -1625,6 +1650,8 @@ def _unstack_once( # Indexer is a list of lists of locations. Each list is the locations # on the new dimension. This is robust to the data being sparse; in that # case the destinations will be NaN / zero. + # sparse doesn't support item assigment, + # https://github.com/pydata/sparse/issues/114 data[(..., *indexer)] = reordered return self._replace(dims=new_dims, data=data) @@ -1655,7 +1682,7 @@ def unstack(self, dimensions=None, **dimensions_kwargs): unstacked : Variable Variable with the same attributes but unstacked data. - See also + See Also -------- Variable.stack DataArray.unstack @@ -1673,6 +1700,21 @@ def fillna(self, value): def where(self, cond, other=dtypes.NA): return ops.where_method(self, cond, other) + def clip(self, min=None, max=None): + """ + Return an array whose values are limited to ``[min, max]``. + At least one of max or min must be given. + + Refer to `numpy.clip` for full documentation. + + See Also + -------- + numpy.clip : equivalent function + """ + from .computation import apply_ufunc + + return apply_ufunc(np.clip, self, min, max, dask="allowed") + def reduce( self, func, @@ -1760,7 +1802,14 @@ def reduce( return Variable(dims, data, attrs=attrs) @classmethod - def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): """Concatenate variables along a new or existing dimension. Parameters @@ -1783,6 +1832,18 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): This option is used internally to speed-up groupby operations. If `shortcut` is True, some checks of internal consistency between arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. Returns ------- @@ -1790,6 +1851,8 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): Concatenated Variable formed by stacking all the supplied variables along the given dimension. """ + from .merge import merge_attrs + if not isinstance(dim, str): (dim,) = dim.dims @@ -1814,7 +1877,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): dims = (dim,) + first_var.dims data = duck_array_ops.stack(arrays, axis=axis) - attrs = dict(first_var.attrs) + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) encoding = dict(first_var.encoding) if not shortcut: for var in variables: @@ -1900,7 +1965,6 @@ def quantile( * higher: ``j``. * nearest: ``i`` or ``j``, whichever is nearest. * midpoint: ``(i + j) / 2``. - keep_attrs : bool, optional If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new @@ -1917,7 +1981,7 @@ def quantile( See Also -------- - numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile DataArray.quantile """ @@ -1988,6 +2052,12 @@ def rank(self, dim, pct=False): -------- Dataset.rank, DataArray.rank """ + if not OPTIONS["use_bottleneck"]: + raise RuntimeError( + "rank requires bottleneck to be enabled." + " Call `xr.set_options(use_bottleneck=True)` to enable it." + ) + import bottleneck as bn data = self.data @@ -2027,7 +2097,7 @@ def rolling_window( For nd-rolling, should be list of integers. window_dim : str New name of the window dimension. - For nd-rolling, should be list of integers. + For nd-rolling, should be list of strings. center : bool, default: False If True, pad fill_value for both ends. Otherwise, pad in the head of the axis. @@ -2070,26 +2140,56 @@ def rolling_window( """ if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) - array = self.astype(dtype, copy=False).data + var = self.astype(dtype, copy=False) else: dtype = self.dtype - array = self.data + var = self - if isinstance(dim, list): - assert len(dim) == len(window) - assert len(dim) == len(window_dim) - assert len(dim) == len(center) - else: + if utils.is_scalar(dim): + for name, arg in zip( + ["window", "window_dim", "center"], [window, window_dim, center] + ): + if not utils.is_scalar(arg): + raise ValueError( + f"Expected {name}={arg!r} to be a scalar like 'dim'." + ) dim = [dim] - window = [window] - window_dim = [window_dim] - center = [center] + + # dim is now a list + nroll = len(dim) + if utils.is_scalar(window): + window = [window] * nroll + if utils.is_scalar(window_dim): + window_dim = [window_dim] * nroll + if utils.is_scalar(center): + center = [center] * nroll + if ( + len(dim) != len(window) + or len(dim) != len(window_dim) + or len(dim) != len(center) + ): + raise ValueError( + "'dim', 'window', 'window_dim', and 'center' must be the same length. " + f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r}," + f" and center={center!r}." + ) + + pads = {} + for d, win, cent in zip(dim, window, center): + if cent: + start = win // 2 # 10 -> 5, 9 -> 4 + end = win - 1 - start + pads[d] = (start, end) + else: + pads[d] = (win - 1, 0) + + padded = var.pad(pads, mode="constant", constant_values=fill_value) axis = [self.get_axis_num(d) for d in dim] new_dims = self.dims + tuple(window_dim) return Variable( new_dims, - duck_array_ops.rolling_window( - array, axis=axis, window=window, center=center, fill_value=fill_value + duck_array_ops.sliding_window_view( + padded.data, window_shape=window, axis=axis ), ) @@ -2100,18 +2200,19 @@ def coarsen( Apply reduction function. """ windows = {k: v for k, v in windows.items() if k in self.dims} - if not windows: - return self.copy() if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) + keep_attrs = _get_keep_attrs(default=True) if keep_attrs: _attrs = self.attrs else: _attrs = None - reshaped, axes = self._coarsen_reshape(windows, boundary, side) + if not windows: + return self._replace(attrs=_attrs) + + reshaped, axes = self.coarsen_reshape(windows, boundary, side) if isinstance(func, str): name = func func = getattr(duck_array_ops, name, None) @@ -2120,7 +2221,7 @@ def coarsen( return self._replace(data=func(reshaped, axis=axes, **kwargs), attrs=_attrs) - def _coarsen_reshape(self, windows, boundary, side): + def coarsen_reshape(self, windows, boundary, side): """ Construct a reshaped-array for coarsen """ @@ -2136,7 +2237,9 @@ def _coarsen_reshape(self, windows, boundary, side): for d, window in windows.items(): if window <= 0: - raise ValueError(f"window must be > 0. Given {window}") + raise ValueError( + f"window must be > 0. Given {window} for dimension {d}" + ) variable = self for d, window in windows.items(): @@ -2146,8 +2249,8 @@ def _coarsen_reshape(self, windows, boundary, side): if boundary[d] == "exact": if n * window != size: raise ValueError( - "Could not coarsen a dimension of size {} with " - "window {}".format(size, window) + f"Could not coarsen a dimension of size {size} with " + f"window {window} and boundary='exact'. Try a different 'boundary' option." ) elif boundary[d] == "trim": if side[d] == "left": @@ -2255,64 +2358,50 @@ def notnull(self, keep_attrs: bool = None): @property def real(self): - return type(self)(self.dims, self.data.real, self._attrs) + return self._replace(data=self.data.real) @property def imag(self): - return type(self)(self.dims, self.data.imag, self._attrs) + return self._replace(data=self.data.imag) def __array_wrap__(self, obj, context=None): return Variable(self.dims, obj) - @staticmethod - def _unary_op(f): - @functools.wraps(f) - def func(self, *args, **kwargs): - keep_attrs = kwargs.pop("keep_attrs", None) - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - with np.errstate(all="ignore"): - result = self.__array_wrap__(f(self.data, *args, **kwargs)) - if keep_attrs: - result.attrs = self.attrs - return result - - return func - - @staticmethod - def _binary_op(f, reflexive=False, **ignored_kwargs): - @functools.wraps(f) - def func(self, other): - if isinstance(other, (xr.DataArray, xr.Dataset)): - return NotImplemented - self_data, other_data, dims = _broadcast_compat_data(self, other) - keep_attrs = _get_keep_attrs(default=False) - attrs = self._attrs if keep_attrs else None - with np.errstate(all="ignore"): - new_data = ( - f(self_data, other_data) - if not reflexive - else f(other_data, self_data) - ) - result = Variable(dims, new_data, attrs=attrs) + def _unary_op(self, f, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with np.errstate(all="ignore"): + result = self.__array_wrap__(f(self.data, *args, **kwargs)) + if keep_attrs: + result.attrs = self.attrs return result - return func - - @staticmethod - def _inplace_binary_op(f): - @functools.wraps(f) - def func(self, other): - if isinstance(other, xr.Dataset): - raise TypeError("cannot add a Dataset to a Variable in-place") + def _binary_op(self, other, f, reflexive=False): + if isinstance(other, (xr.DataArray, xr.Dataset)): + return NotImplemented + if reflexive and issubclass(type(self), type(other)): + other_data, self_data, dims = _broadcast_compat_data(other, self) + else: self_data, other_data, dims = _broadcast_compat_data(self, other) - if dims != self.dims: - raise ValueError("dimensions cannot change for in-place operations") - with np.errstate(all="ignore"): - self.values = f(self_data, other_data) - return self + keep_attrs = _get_keep_attrs(default=False) + attrs = self._attrs if keep_attrs else None + with np.errstate(all="ignore"): + new_data = ( + f(self_data, other_data) if not reflexive else f(other_data, self_data) + ) + result = Variable(dims, new_data, attrs=attrs) + return result - return func + def _inplace_binary_op(self, other, f): + if isinstance(other, xr.Dataset): + raise TypeError("cannot add a Dataset to a Variable in-place") + self_data, other_data, dims = _broadcast_compat_data(self, other) + if dims != self.dims: + raise ValueError("dimensions cannot change for in-place operations") + with np.errstate(all="ignore"): + self.values = f(self_data, other_data) + return self def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): """A (private) method to convert datetime array to numeric dtype @@ -2432,7 +2521,7 @@ def argmin( ------- result : Variable or dict of Variable - See also + See Also -------- DataArray.argmin, DataArray.idxmin """ @@ -2477,16 +2566,13 @@ def argmax( ------- result : Variable or dict of Variable - See also + See Also -------- DataArray.argmax, DataArray.idxmax """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) -ops.inject_all_ops_and_reduce_methods(Variable) - - class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. @@ -2503,11 +2589,11 @@ class IndexVariable(Variable): def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): super().__init__(dims, data, attrs, encoding, fastpath) if self.ndim != 1: - raise ValueError("%s objects must be 1-dimensional" % type(self).__name__) + raise ValueError(f"{type(self).__name__} objects must be 1-dimensional") # Unlike in Variable, always eagerly load values into memory - if not isinstance(self._data, PandasIndexAdapter): - self._data = PandasIndexAdapter(self._data) + if not isinstance(self._data, PandasIndexingAdapter): + self._data = PandasIndexingAdapter(self._data) def __dask_tokenize__(self): from dask.base import normalize_token @@ -2520,14 +2606,14 @@ def load(self): return self # https://github.com/python/mypy/issues/1465 - @Variable.data.setter # type: ignore + @Variable.data.setter # type: ignore[attr-defined] def data(self, data): raise ValueError( f"Cannot assign to the .data attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. " f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." ) - @Variable.values.setter # type: ignore + @Variable.values.setter # type: ignore[attr-defined] def values(self, values): raise ValueError( f"Cannot assign to the .values attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. " @@ -2551,18 +2637,27 @@ def _finalize_indexing_result(self, dims, data): # returns Variable rather than IndexVariable if multi-dimensional return Variable(dims, data, self._attrs, self._encoding) else: - return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) + return self._replace(dims=dims, data=data) def __setitem__(self, key, value): - raise TypeError("%s values cannot be modified" % type(self).__name__) + raise TypeError(f"{type(self).__name__} values cannot be modified") @classmethod - def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): + def concat( + cls, + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", + ): """Specialized version of Variable.concat for IndexVariable objects. This exists because we want to avoid converting Index objects to NumPy arrays, if possible. """ + from .merge import merge_attrs + if not isinstance(dim, str): (dim,) = dim.dims @@ -2589,12 +2684,13 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): # keep as str if possible as pandas.Index uses object (converts to numpy array) data = maybe_coerce_to_str(data, variables) - attrs = dict(first_var.attrs) + attrs = merge_attrs( + [var.attrs for var in variables], combine_attrs=combine_attrs + ) if not shortcut: for var in variables: if var.dims != first_var.dims: raise ValueError("inconsistent dimensions") - utils.remove_incompatible_items(attrs, var.attrs) return cls(first_var.dims, data, attrs) @@ -2632,7 +2728,7 @@ def copy(self, deep=True, data=None): data.shape, self.shape ) ) - return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) + return self._replace(data=data) def equals(self, other, equiv=None): # if equiv is specified, super up @@ -2687,7 +2783,7 @@ def level_names(self): def get_level_variable(self, level): """Return a new IndexVariable from a given MultiIndex level.""" if self.level_names is None: - raise ValueError("IndexVariable %r has no MultiIndex" % self.name) + raise ValueError(f"IndexVariable {self.name!r} has no MultiIndex") index = self.to_index() return type(self)(self.dims, index.get_level_values(level)) @@ -2712,7 +2808,7 @@ def _unified_dims(variables): if len(set(var_dims)) < len(var_dims): raise ValueError( "broadcasting cannot handle duplicate " - "dimensions: %r" % list(var_dims) + f"dimensions: {list(var_dims)!r}" ) for d, s in zip(var_dims, var.shape): if d not in all_dims: @@ -2720,8 +2816,7 @@ def _unified_dims(variables): elif all_dims[d] != s: raise ValueError( "operands cannot be broadcast together " - "with mismatched lengths for dimension %r: %s" - % (d, (all_dims[d], s)) + f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" ) return all_dims @@ -2730,7 +2825,7 @@ def _broadcast_compat_variables(*variables): """Create broadcast compatible variables, with the same dimensions. Unlike the result of broadcast_variables(), some variables may have - dimensions of size 1 instead of the the size of the broadcast dimension. + dimensions of size 1 instead of the size of the broadcast dimension. """ dims = tuple(_unified_dims(variables)) return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables) @@ -2768,7 +2863,13 @@ def _broadcast_compat_data(self, other): return self_data, other_data, dims -def concat(variables, dim="concat_dim", positions=None, shortcut=False): +def concat( + variables, + dim="concat_dim", + positions=None, + shortcut=False, + combine_attrs="override", +): """Concatenate variables along a new or existing dimension. Parameters @@ -2791,6 +2892,18 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): This option is used internally to speed-up groupby operations. If `shortcut` is True, some checks of internal consistency between arrays to concatenate are skipped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"}, default: "override" + String indicating how to combine attrs of the objects being merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. Returns ------- @@ -2800,9 +2913,9 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): """ variables = list(variables) if all(isinstance(v, IndexVariable) for v in variables): - return IndexVariable.concat(variables, dim, positions, shortcut) + return IndexVariable.concat(variables, dim, positions, shortcut, combine_attrs) else: - return Variable.concat(variables, dim, positions, shortcut) + return Variable.concat(variables, dim, positions, shortcut, combine_attrs) def assert_unique_multiindex_level_names(variables): @@ -2815,7 +2928,7 @@ def assert_unique_multiindex_level_names(variables): level_names = defaultdict(list) all_level_names = set() for var_name, var in variables.items(): - if isinstance(var._data, PandasIndexAdapter): + if isinstance(var._data, PandasIndexingAdapter): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: @@ -2825,12 +2938,12 @@ def assert_unique_multiindex_level_names(variables): for k, v in level_names.items(): if k in variables: - v.append("(%s)" % k) + v.append(f"({k})") duplicate_names = [v for v in level_names.values() if len(v) > 1] if duplicate_names: conflict_str = "\n".join(", ".join(v) for v in duplicate_names) - raise ValueError("conflicting MultiIndex level name(s):\n%s" % conflict_str) + raise ValueError(f"conflicting MultiIndex level name(s):\n{conflict_str}") # Check confliction between level names and dimensions GH:2299 for k, v in variables.items(): for d in v.dims: diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 449a7200ee7..e8838b07157 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -119,6 +119,19 @@ def _weight_check(w): self.obj: T_DataWithCoords = obj self.weights: "DataArray" = weights + def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]): + """raise an error if any dimension is missing""" + + if isinstance(dim, str) or not isinstance(dim, Iterable): + dims = [dim] if dim else [] + else: + dims = list(dim) + missing_dims = set(dims) - set(self.obj.dims) - set(self.weights.dims) + if missing_dims: + raise ValueError( + f"{self.__class__.__name__} does not contain the dimensions: {missing_dims}" + ) + @staticmethod def _reduce( da: "DataArray", @@ -146,7 +159,7 @@ def _reduce( def _sum_of_weights( self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None ) -> "DataArray": - """ Calculate the sum of weights, accounting for missing values """ + """Calculate the sum of weights, accounting for missing values""" # we need to mask data values that are nan; else the weights are wrong mask = da.notnull() @@ -236,6 +249,8 @@ def __repr__(self): class DataArrayWeighted(Weighted["DataArray"]): def _implementation(self, func, dim, **kwargs) -> "DataArray": + self._check_dim(dim) + dataset = self.obj._to_temp_dataset() dataset = dataset.map(func, dim=dim, **kwargs) return self.obj._from_temp_dataset(dataset) @@ -244,6 +259,8 @@ def _implementation(self, func, dim, **kwargs) -> "DataArray": class DatasetWeighted(Weighted["Dataset"]): def _implementation(self, func, dim, **kwargs) -> "Dataset": + self._check_dim(dim) + return self.obj.map(func, dim=dim, **kwargs) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 86a09506824..28ae0cf32e7 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,6 +1,6 @@ from .dataset_plot import scatter from .facetgrid import FacetGrid -from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step +from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface __all__ = [ "plot", @@ -13,4 +13,5 @@ "pcolormesh", "FacetGrid", "scatter", + "surface", ] diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 6d942e1b0fa..c1aedd570bc 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -7,6 +7,7 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _get_nice_quiver_magnitude, _is_numeric, _process_cmap_cbar_kwargs, get_axis, @@ -17,7 +18,7 @@ _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): dvars = set(ds.variables.keys()) error_msg = " must be one of ({:s})".format(", ".join(dvars)) @@ -48,11 +49,36 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): add_colorbar = False add_legend = False else: - if add_guide is True: + if add_guide is True and funcname not in ("quiver", "streamplot"): raise ValueError("Cannot set add_guide when hue is None.") add_legend = False add_colorbar = False + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + else: + add_quiverkey = False + + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + if hue_style is not None and hue_style not in ["discrete", "continuous"]: raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") @@ -66,6 +92,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): return { "add_colorbar": add_colorbar, "add_legend": add_legend, + "add_quiverkey": add_quiverkey, "hue_label": hue_label, "hue_style": hue_style, "xlabel": label_from_attrs(ds[x]), @@ -169,40 +196,56 @@ def _dsplot(plotfunc): ds : Dataset x, y : str - Variable names for x, y axis. + Variable names for the *x* and *y* grid positions. + u, v : str, optional + Variable names for the *u* and *v* velocities + (in *x* and *y* direction, respectively; quiver/streamplot plots only). hue: str, optional - Variable by which to color scattered points - hue_style: str, optional - Can be either 'discrete' (legend) or 'continuous' (color bar). + Variable by which to color scatter points or arrows. + hue_style: {'continuous', 'discrete'}, optional + How to use the ``hue`` variable: + + - ``'continuous'`` -- continuous color scale + (default for numeric ``hue`` variables) + - ``'discrete'`` -- a color for each unique value, using the default color cycle + (default for non-numeric ``hue`` variables) markersize: str, optional - scatter only. Variable by which to vary size of scattered points. - size_norm: optional - Either None or 'Norm' instance to normalize the 'markersize' variable. - add_guide: bool, optional - Add a guide that depends on hue_style - - for "discrete", build a legend. - This is the default for non-numeric `hue` variables. - - for "continuous", build a colorbar + Variable by which to vary the size of scattered points (scatter plot only). + size_norm: matplotlib.colors.Normalize or tuple, optional + Used to normalize the ``markersize`` variable. + If a tuple is passed, the values will be passed to + :py:class:`matplotlib:matplotlib.colors.Normalize` as arguments. + Default: no normalization (``vmin=None``, ``vmax=None``, ``clip=False``). + scale: scalar, optional + Quiver only. Number of data units per arrow length unit. + Use this to control the length of the arrows: larger values lead to + smaller arrows. + add_guide: bool, optional, default: True + Add a guide that depends on ``hue_style``: + + - ``'continuous'`` -- build a colorbar + - ``'discrete'`` -- build a legend row : str, optional - If passed, make row faceted plots on this dimension name + If passed, make row faceted plots on this dimension name. col : str, optional - If passed, make column faceted plots on this dimension name + If passed, make column faceted plots on this dimension name. col_wrap : int, optional - Use together with ``col`` to wrap faceted plots + Use together with ``col`` to wrap faceted plots. ax : matplotlib axes object, optional - If None, uses the current axis. Not applicable when using facets. + If ``None``, use the current axes. Not applicable when using facets. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only applies - to FacetGrid plotting. + Dictionary of keyword arguments for Matplotlib subplots + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + Only applies to FacetGrid plotting. aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in inches. Only used if a ``size`` is provided. size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - norm : ``matplotlib.colors.Normalize`` instance, optional - If the ``norm`` has vmin or vmax specified, the corresponding kwarg - must be None. + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. vmin, vmax : float, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -210,36 +253,40 @@ def _dsplot(plotfunc): ``center``. Setting both values prevents use of a diverging colormap. If discrete levels are provided as an explicit list, both of these values are ignored. - cmap : str or colormap, optional + cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a - matplotlib colormap name or object. If not provided, this will - be either ``viridis`` (if the function infers a sequential - dataset) or ``RdBu_r`` (if the function infers a diverging - dataset). When `Seaborn` is installed, ``cmap`` may also be a - `seaborn` color palette. If ``cmap`` is seaborn color palette - and the plot type is not ``contour`` or ``contourf``, ``levels`` - must also be specified. - colors : color-like or list of color-like, optional - A single color or a list of colors. If the plot type is not ``contour`` - or ``contourf``, the ``levels`` argument is required. + Matplotlib colormap name or object. If not provided, this will + be either ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging + dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette, + ``levels`` must also be specified. + colors : str or array-like of color-like, optional + A single color or a list of colors. The ``levels`` argument + is required. center : float, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a diverging colormap. robust : bool, optional - If True and ``vmin`` or ``vmax`` are absent, the colormap range is + If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is computed with 2nd and 98th percentiles instead of the extreme values. - extend : {"neither", "both", "min", "max"}, optional + extend : {'neither', 'both', 'min', 'max'}, optional How to draw arrows extending the colorbar beyond its limits. If not - provided, extend is inferred from vmin, vmax and the data limits. - levels : int or list-like object, optional - Split the colormap (cmap) into discrete color intervals. If an integer + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer is provided, "nice" levels are chosen based on the data range: this can imply that the final number of levels is not exactly the expected one. Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to setting ``levels=np.linspace(vmin, vmax, N)``. **kwargs : optional - Additional keyword arguments to matplotlib + Additional keyword arguments to wrapped Matplotlib function. """ # Build on the original docstring @@ -250,6 +297,8 @@ def newplotfunc( ds, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -282,7 +331,9 @@ def newplotfunc( if _is_facetgrid: # facetgrid call meta_data = kwargs.pop("meta_data") else: - meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__ + ) hue_style = meta_data["hue_style"] @@ -317,13 +368,21 @@ def newplotfunc( else: cmap_params_subset = {} + if (u is not None or v is not None) and plotfunc.__name__ not in ( + "quiver", + "streamplot", + ): + raise ValueError("u, v are only allowed for quiver or streamplot plots.") + primitive = plotfunc( ds=ds, x=x, y=y, + ax=ax, + u=u, + v=v, hue=hue, hue_style=hue_style, - ax=ax, cmap_params=cmap_params_subset, **kwargs, ) @@ -344,6 +403,25 @@ def newplotfunc( cbar_kwargs["label"] = meta_data.get("hue_label", None) _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + if meta_data["add_quiverkey"]: + magnitude = _get_nice_quiver_magnitude(ds[u], ds[v]) + units = ds[u].attrs.get("units", "") + ax.quiverkey( + primitive, + X=0.85, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + if plotfunc.__name__ in ("quiver", "streamplot"): + title = ds[u]._title_for_slice() + else: + title = ds[x]._title_for_slice() + ax.set_title(title) + return primitive @functools.wraps(newplotfunc) @@ -351,6 +429,8 @@ def plotmethod( _PlotMethods_obj, x=None, y=None, + u=None, + v=None, hue=None, hue_style=None, col=None, @@ -401,6 +481,8 @@ def plotmethod( def scatter(ds, x, y, ax, **kwargs): """ Scatter Dataset data variables against each other. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`. """ if "add_colorbar" in kwargs or "add_legend" in kwargs: @@ -417,6 +499,10 @@ def scatter(ds, x, y, ax, **kwargs): size_norm = kwargs.pop("size_norm", None) size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid + # Remove `u` and `v` so they don't get passed to `ax.scatter` + kwargs.pop("u", None) + kwargs.pop("v", None) + # need to infer size_mapping with full dataset data = _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping) @@ -450,3 +536,89 @@ def scatter(ds, x, y, ax, **kwargs): ) return primitive + + +@_dsplot +def quiver(ds, x, y, ax, u, v, **kwargs): + """Quiver plot of Dataset variables. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. + """ + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for quiver plots.") + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + args.append(ds[hue].values) + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + kwargs.setdefault("pivot", "middle") + hdl = ax.quiver(*args, **kwargs, **cmap_params) + return hdl + + +@_dsplot +def streamplot(ds, x, y, ax, u, v, **kwargs): + """Plot streamlines of Dataset variables. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. + """ + import matplotlib as mpl + + if x is None or y is None or u is None or v is None: + raise ValueError("Must specify x, y, u, v for streamplot plots.") + + # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to + # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so + # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so + # the dimension of y must be the first dimension. If x and y are both 2d, assume the + # user has got them right already. + if len(ds[x].dims) == 1: + xdim = ds[x].dims[0] + if len(ds[y].dims) == 1: + ydim = ds[y].dims[0] + if xdim is not None and ydim is None: + ydim = set(ds[y].dims) - set([xdim]) + if ydim is not None and xdim is None: + xdim = set(ds[x].dims) - set([ydim]) + + x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + + if xdim is not None and ydim is not None: + # Need to ensure the arrays are transposed correctly + x = x.transpose(ydim, xdim) + y = y.transpose(ydim, xdim) + u = u.transpose(ydim, xdim) + v = v.transpose(ydim, xdim) + + args = [x.values, y.values, u.values, v.values] + hue = kwargs.pop("hue") + cmap_params = kwargs.pop("cmap_params") + + if hue: + kwargs["color"] = ds[hue].values + + # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params + if not cmap_params["norm"]: + cmap_params["norm"] = mpl.colors.Normalize( + cmap_params.pop("vmin"), cmap_params.pop("vmax") + ) + + kwargs.pop("hue_style") + hdl = ax.streamplot(*args, **kwargs, **cmap_params) + + # Return .lines so colorbar creation works properly + return hdl.lines diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 937283e8230..8fff4d73697 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -7,6 +7,7 @@ from ..core.formatting import format_item from .utils import ( + _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, @@ -35,13 +36,13 @@ def _nicetitle(coord, value, maxchar, template): class FacetGrid: """ - Initialize the matplotlib figure and FacetGrid object. + Initialize the Matplotlib figure and FacetGrid object. The :class:`FacetGrid` is an object that links a xarray DataArray to - a matplotlib figure with a particular structure. + a Matplotlib figure with a particular structure. In particular, :class:`FacetGrid` is used to draw plots with multiple - Axes where each Axes shows the same relationship conditioned on + axes, where each axes shows the same relationship conditioned on different levels of some dimension. It's possible to condition on up to two variables by assigning variables to the rows and columns of the grid. @@ -59,19 +60,19 @@ class FacetGrid: Attributes ---------- - axes : numpy object array - Contains axes in corresponding position, as returned from - plt.subplots - col_labels : list - list of :class:`matplotlib.text.Text` instances corresponding to column titles. - row_labels : list - list of :class:`matplotlib.text.Text` instances corresponding to row titles. - fig : matplotlib.Figure - The figure containing all the axes - name_dicts : numpy object array - Contains dictionaries mapping coordinate names to values. None is - used as a sentinel value for axes which should remain empty, ie. - sometimes the bottom right grid + axes : ndarray of matplotlib.axes.Axes + Array containing axes in corresponding position, as returned from + :py:func:`matplotlib.pyplot.subplots`. + col_labels : list of matplotlib.text.Text + Column titles. + row_labels : list of matplotlib.text.Text + Row titles. + fig : matplotlib.figure.Figure + The figure containing all the axes. + name_dicts : ndarray of dict + Array containing dictionaries mapping coordinate names to values. ``None`` is + used as a sentinel value for axes that should remain empty, i.e., + sometimes the rightmost grid positions in the bottom row. """ def __init__( @@ -91,26 +92,28 @@ def __init__( Parameters ---------- data : DataArray - xarray DataArray to be plotted - row, col : strings + xarray DataArray to be plotted. + row, col : str Dimesion names that define subsets of the data, which will be drawn on separate facets in the grid. col_wrap : int, optional - "Wrap" the column variable at this width, so that the column facets + "Wrap" the grid the for the column variable after this number of columns, + adding rows if ``col_wrap`` is less than the number of facets. sharex : bool, optional - If true, the facets will share x axes + If true, the facets will share *x* axes. sharey : bool, optional - If true, the facets will share y axes + If true, the facets will share *y* axes. figsize : tuple, optional A tuple (width, height) of the figure in inches. If set, overrides ``size`` and ``aspect``. aspect : scalar, optional Aspect ratio of each facet, so that ``aspect * size`` gives the - width of each facet in inches + width of each facet in inches. size : scalar, optional - Height (in inches) of each facet. See also: ``aspect`` + Height (in inches) of each facet. See also: ``aspect``. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots + Dictionary of keyword arguments for Matplotlib subplots + (:py:func:`matplotlib.pyplot.subplots`). """ @@ -213,7 +216,11 @@ def __init__( self.axes = axes self.row_names = row_names self.col_names = col_names + + # guides self.figlegend = None + self.quiverkey = None + self.cbar = None # Next the private variables self._single_group = single_group @@ -239,7 +246,7 @@ def _bottom_axes(self): return self.axes[-1, :] def _get_subset(self, key: Mapping, expected_ndim): - """ Index with "key" using either .loc or get_group as appropriate. """ + """Index with "key" using either .loc or get_group as appropriate.""" if self._groupby: result = self._obj[list(key.values())[0]] else: @@ -266,7 +273,7 @@ def map_dataarray(self, func, x, y, **kwargs): plotting method such as `xarray.plot.imshow` x, y : string Names of the coordinates to plot on x, y axes - kwargs : + **kwargs additional keyword arguments to func Returns @@ -289,7 +296,9 @@ def map_dataarray(self, func, x, y, **kwargs): if k not in {"cmap", "colors", "cbar_kwargs", "levels"} } func_kwargs.update(cmap_params) - func_kwargs.update({"add_colorbar": False, "add_labels": False}) + func_kwargs["add_colorbar"] = False + if func.__name__ != "surface": + func_kwargs["add_labels"] = False # Get x, y labels for the first subplot x, y = _infer_xy_labels( @@ -361,14 +370,15 @@ def map_dataset( from .dataset_plot import _infer_meta_data, _parse_size kwargs["add_guide"] = False - kwargs["_is_facetgrid"] = True if kwargs.get("markersize", None): kwargs["size_mapping"] = _parse_size( self.data[kwargs["markersize"]], kwargs.pop("size_norm", None) ) - meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, add_guide) + meta_data = _infer_meta_data( + self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__ + ) kwargs["meta_data"] = meta_data if hue and meta_data["hue_style"] == "continuous": @@ -378,6 +388,12 @@ def map_dataset( kwargs["meta_data"]["cmap_params"] = cmap_params kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs + kwargs["_is_facetgrid"] = True + + if func.__name__ == "quiver" and "scale" not in kwargs: + raise ValueError("Please provide scale.") + # TODO: come up with an algorithm for reasonable scale choice + for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value if d is not None: @@ -399,6 +415,9 @@ def map_dataset( elif meta_data["add_colorbar"]: self.add_colorbar(label=self._hue_label, **cbar_kwargs) + if meta_data["add_quiverkey"]: + self.add_quiverkey(kwargs["u"], kwargs["v"]) + return self def _finalize_grid(self, *axlabels): @@ -414,30 +433,22 @@ def _finalize_grid(self, *axlabels): self._finalized = True - def add_legend(self, **kwargs): - figlegend = self.fig.legend( - handles=self._mappables[-1], - labels=list(self._hue_var.values), - title=self._hue_label, - loc="center right", - **kwargs, - ) - - self.figlegend = figlegend + def _adjust_fig_for_guide(self, guide): # Draw the plot to set the bounding boxes correctly - self.fig.draw(self.fig.canvas.get_renderer()) + renderer = self.fig.canvas.get_renderer() + self.fig.draw(renderer) # Calculate and set the new width of the figure so the legend fits - legend_width = figlegend.get_window_extent().width / self.fig.dpi + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi figure_width = self.fig.get_figwidth() - self.fig.set_figwidth(figure_width + legend_width) + self.fig.set_figwidth(figure_width + guide_width) # Draw the plot again to get the new transformations - self.fig.draw(self.fig.canvas.get_renderer()) + self.fig.draw(renderer) # Now calculate how much space we need on the right side - legend_width = figlegend.get_window_extent().width / self.fig.dpi - space_needed = legend_width / (figure_width + legend_width) + 0.02 + guide_width = guide.get_window_extent(renderer).width / self.fig.dpi + space_needed = guide_width / (figure_width + guide_width) + 0.02 # margin = .01 # _space_needed = margin + space_needed right = 1 - space_needed @@ -445,8 +456,18 @@ def add_legend(self, **kwargs): # Place the subplot axes to give space for the legend self.fig.subplots_adjust(right=right) + def add_legend(self, **kwargs): + self.figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self._hue_var.values), + title=self._hue_label, + loc="center right", + **kwargs, + ) + self._adjust_fig_for_guide(self.figlegend) + def add_colorbar(self, **kwargs): - """Draw a colorbar""" + """Draw a colorbar.""" kwargs = kwargs.copy() if self._cmap_extend is not None: kwargs.setdefault("extend", self._cmap_extend) @@ -460,6 +481,26 @@ def add_colorbar(self, **kwargs): ) return self + def add_quiverkey(self, u, v, **kwargs): + kwargs = kwargs.copy() + + magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v]) + units = self.data[u].attrs.get("units", "") + self.quiverkey = self.axes.flat[-1].quiverkey( + self._mappables[-1], + X=0.8, + Y=0.9, + U=magnitude, + label=f"{magnitude}\n{units}", + labelpos="E", + coordinates="figure", + ) + + # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0 + # https://github.com/matplotlib/matplotlib/issues/18530 + # self._adjust_fig_for_guide(self.quiverkey.text) + return self + def set_axis_labels(self, x_var=None, y_var=None): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: @@ -504,7 +545,7 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar Template for plot titles containing {coord} and {value} maxchar : int Truncate titles at maxchar - kwargs : keyword args + **kwargs : keyword args additional arguments to matplotlib.text Returns @@ -559,7 +600,7 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, fontsize=_FONTSIZE): """ - Set and control tick behavior + Set and control tick behavior. Parameters ---------- @@ -600,11 +641,11 @@ def map(self, func, *args, **kwargs): must plot to the currently active matplotlib Axes and take a `color` keyword argument. If faceting on the `hue` dimension, it must also take a `label` keyword argument. - args : strings + *args : strings Column names in self.data that identify variables with data to plot. The data for each variable is passed to `func` in the order the variables are specified in the call. - kwargs : keyword arguments + **kwargs : keyword arguments All keyword arguments are passed to the plotting function. Returns diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index c29b3926d0d..41c2153c27a 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -7,17 +7,22 @@ Dataset.plot._____ """ import functools +from distutils.version import LooseVersion import numpy as np import pandas as pd +from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _adjust_legend_subtitles, _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, + _is_numeric, + _legend_add_subtitle, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -26,8 +31,132 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, + legend_elements, ) +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) + + +def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): + def _determine_array(darray, name, array_style): + """Find and determine what type of array it is.""" + array = darray[name] + array_is_numeric = _is_numeric(array.values) + + if array_style is None: + array_style = "continuous" if array_is_numeric else "discrete" + elif array_style not in ["discrete", "continuous"]: + raise ValueError( + f"The style '{array_style}' is not valid, " + "valid options are None, 'discrete' or 'continuous'." + ) + + array_label = label_from_attrs(array) + + return array, array_style, array_label + + # Add nice looking labels: + out = dict(ylabel=label_from_attrs(darray)) + out.update( + { + k: label_from_attrs(darray[v]) if v in darray.coords else None + for k, v in [("xlabel", x), ("zlabel", z)] + } + ) + + # Add styles and labels for the dataarrays: + for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: + tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" + if a: + out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) + else: + out[tp], out[stl], out[lbl] = None, None, None + + return out + + +# copied from seaborn +def _parse_size(data, norm, width): + """ + Determine what type of data it is. Then normalize it to width. + + If the data is categorical, normalize it to numbers. + """ + plt = import_matplotlib_pyplot() + + if data is None: + return None + + data = data.values.ravel() + + if not _is_numeric(data): + # Data is categorical. + # Use pd.unique instead of np.unique because that keeps + # the order of the labels: + levels = pd.unique(data) + numbers = np.arange(1, 1 + len(levels)) + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = width + # width_range = min_width, max_width + + if norm is None: + norm = plt.Normalize() + elif isinstance(norm, tuple): + norm = plt.Normalize(*norm) + elif not isinstance(norm, plt.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +def _infer_scatter_data( + darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) +): + # Broadcast together all the chosen variables: + to_broadcast = dict(y=darray) + to_broadcast.update( + {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} + ) + to_broadcast.update( + {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} + ) + broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) + + # Normalize hue and size and create lookup tables: + for type_, mapping, norm, width in [ + ("hue", None, None, [0, 1]), + ("size", size_mapping, size_norm, size_range), + ]: + broadcasted_type = broadcasted.get(type_, None) + if broadcasted_type is not None: + if mapping is None: + mapping = _parse_size(broadcasted_type, norm, width) + + broadcasted[type_] = broadcasted_type.copy( + data=np.reshape( + mapping.loc[broadcasted_type.values.ravel()].values, + broadcasted_type.shape, + ) + ) + broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) + + return broadcasted + def _infer_line_data(darray, x, y, hue): @@ -65,6 +194,8 @@ def _infer_line_data(darray, x, y, hue): raise ValueError("For 2D inputs, please specify either hue, x or y.") if y is None: + if hue is not None: + _assert_valid_xy(darray, hue, "hue") xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) xplt = darray[xname] if xplt.ndim > 1: @@ -122,14 +253,14 @@ def plot( **kwargs, ): """ - Default plot of DataArray using matplotlib.pyplot. + Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`. Calls xarray plotting function based on the dimensions of - darray.squeeze() + the squeezed DataArray. =============== =========================== Dimensions Plotting function - --------------- --------------------------- + =============== =========================== 1 :py:func:`xarray.plot.line` 2 :py:func:`xarray.plot.pcolormesh` Anything else :py:func:`xarray.plot.hist` @@ -139,23 +270,27 @@ def plot( ---------- darray : DataArray row : str, optional - If passed, make row faceted plots on this dimension name + If passed, make row faceted plots on this dimension name. col : str, optional - If passed, make column faceted plots on this dimension name + If passed, make column faceted plots on this dimension name. hue : str, optional - If passed, make faceted line plots with hue on this dimension name + If passed, make faceted line plots with hue on this dimension name. col_wrap : int, optional - Use together with ``col`` to wrap faceted plots - ax : matplotlib.axes.Axes, optional - If None, uses the current axis. Not applicable when using facets. + Use together with ``col`` to wrap faceted plots. + ax : matplotlib axes object, optional + If ``None``, use the current axes. Not applicable when using facets. rtol : float, optional Relative tolerance used to determine if the indexes are uniformly spaced. Usually a small positive number. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. + Dictionary of keyword arguments for Matplotlib subplots + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). **kwargs : optional - Additional keyword arguments to matplotlib + Additional keyword arguments for Matplotlib. + See Also + -------- + xarray.DataArray.squeeze """ from ..core.groupby import GroupBy @@ -248,48 +383,50 @@ def line( **kwargs, ): """ - Line plot of DataArray index against values + Line plot of DataArray values. - Wraps :func:`matplotlib:matplotlib.pyplot.plot` + Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. Parameters ---------- darray : DataArray - Must be 1 dimensional + Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. figsize : tuple, optional A tuple (width, height) of the figure in inches. Mutually exclusive with ``size`` and ``ax``. aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in inches. Only used if a ``size`` is provided. size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. ax : matplotlib axes object, optional - Axis on which to plot this figure. By default, use the current axis. + Axes on which to plot. By default, the current is used. Mutually exclusive with ``size`` and ``figsize``. - hue : string, optional + hue : str, optional Dimension or coordinate for which you want multiple lines plotted. If plotting against a 2D coordinate, ``hue`` must be a dimension. - x, y : string, optional - Dimension, coordinate or MultiIndex level for x, y axis. + x, y : str, optional + Dimension, coordinate or multi-index level for *x*, *y* axis. Only one of these may be specified. - The other coordinate plots values from the DataArray on which this + The other will be used for values from the DataArray on which this plot method is called. - xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional - Specifies scaling for the x- and y-axes respectively - xticks, yticks : Specify tick locations for x- and y-axes - xlim, ylim : Specify x- and y-axes limits + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : array-like, optional + Specify *x*- and *y*-axis limits. xincrease : None, True, or False, optional - Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function. + Should the values on the *x* axis be increasing from left to right? + if ``None``, use the default for the Matplotlib function. yincrease : None, True, or False, optional - Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function. + Should the values on the *y* axis be increasing from top to bottom? + if ``None``, use the default for the Matplotlib function. add_legend : bool, optional - Add legend with y axis coordinates (2D inputs only). + Add legend with *y* axis coordinates (2D inputs only). *args, **kwargs : optional - Additional arguments to matplotlib.pyplot.plot + Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. """ from ..core.groupby import GroupBy @@ -335,7 +472,7 @@ def line( # Remove pd.Intervals if contained in xplt.values and/or yplt.values. xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, kwargs + xplt.to_numpy(), yplt.to_numpy(), kwargs ) xlabel = label_from_attrs(xplt, extra=x_suffix) ylabel = label_from_attrs(yplt, extra=y_suffix) @@ -354,7 +491,7 @@ def line( ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label) + ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks @@ -372,30 +509,29 @@ def line( def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): """ - Step plot of DataArray index against values + Step plot of DataArray values. - Similar to :func:`matplotlib:matplotlib.pyplot.step` + Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. Parameters ---------- - where : {"pre", "post", "mid"}, default: "pre" + where : {'pre', 'post', 'mid'}, default: 'pre' Define where the steps should be placed: - - "pre": The y value is continued constantly to the left from + - ``'pre'``: The y value is continued constantly to the left from every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the value ``y[i]``. - - "post": The y value is continued constantly to the right from + - ``'post'``: The y value is continued constantly to the right from every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the value ``y[i]``. - - "mid": Steps occur half-way between the *x* positions. + - ``'mid'``: Steps occur half-way between the *x* positions. Note that this parameter is ignored if one coordinate consists of - :py:func:`pandas.Interval` values, e.g. as a result of + :py:class:`pandas.Interval` values, e.g. as a result of :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual boundaries of the interval are used. - *args, **kwargs : optional - Additional arguments following :py:func:`xarray.plot.line` + Additional arguments for :py:func:`xarray.plot.line`. """ if where not in {"pre", "post", "mid"}: raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") @@ -429,35 +565,35 @@ def hist( **kwargs, ): """ - Histogram of DataArray + Histogram of DataArray. - Wraps :func:`matplotlib:matplotlib.pyplot.hist` + Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. - Plots N dimensional arrays by first flattening the array. + Plots *N*-dimensional arrays by first flattening the array. Parameters ---------- darray : DataArray - Can be any dimension + Can have any number of dimensions. figsize : tuple, optional A tuple (width, height) of the figure in inches. Mutually exclusive with ``size`` and ``ax``. aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in inches. Only used if a ``size`` is provided. size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib.axes.Axes, optional - Axis on which to plot this figure. By default, use the current axis. + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. Mutually exclusive with ``size`` and ``figsize``. **kwargs : optional - Additional keyword arguments to matplotlib.pyplot.hist + Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. """ ax = get_axis(figsize, size, aspect, ax) - no_nan = np.ravel(darray.values) + no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] primitive = ax.hist(no_nan, **kwargs) @@ -470,6 +606,291 @@ def hist( return primitive +def scatter( + darray, + *args, + row=None, + col=None, + figsize=None, + aspect=None, + size=None, + ax=None, + hue=None, + hue_style=None, + x=None, + z=None, + xincrease=None, + yincrease=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + add_legend=None, + add_colorbar=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, + _labels=True, + **kwargs, +): + """ + Scatter plot a DataArray along some coordinates. + + Parameters + ---------- + darray : DataArray + Dataarray to plot. + x, y : str + Variable names for x, y axis. + hue: str, optional + Variable by which to color scattered points + hue_style: str, optional + Can be either 'discrete' (legend) or 'continuous' (color bar). + markersize: str, optional + scatter only. Variable by which to vary size of scattered points. + size_norm: optional + Either None or 'Norm' instance to normalize the 'markersize' variable. + add_guide: bool, optional + Add a guide that depends on hue_style + - for "discrete", build a legend. + This is the default for non-numeric `hue` variables. + - for "continuous", build a colorbar + row : str, optional + If passed, make row faceted plots on this dimension name + col : str, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. + vmin, vmax : float, optional + Values to anchor the colormap, otherwise they are inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting one of these values will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : str or colormap, optional + The mapping from data values to color space. Either a + matplotlib colormap name or object. If not provided, this will + be either ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging + dataset). When `Seaborn` is installed, ``cmap`` may also be a + `seaborn` color palette. If ``cmap`` is seaborn color palette + and the plot type is not ``contour`` or ``contourf``, ``levels`` + must also be specified. + colors : color-like or list of color-like, optional + A single color or a list of colors. If the plot type is not ``contour`` + or ``contourf``, the ``levels`` argument is required. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If True and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {"neither", "both", "min", "max"}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, extend is inferred from vmin, vmax and the data limits. + levels : int or list-like object, optional + Split the colormap (cmap) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional keyword arguments to matplotlib + """ + plt = import_matplotlib_pyplot() + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + subplot_kws = dict(projection="3d") if z is not None else None + return _easy_facetgrid( + darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs + ) + + # Further + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + if _is_facetgrid: + # Why do I need to pop these here? + kwargs.pop("y", None) + kwargs.pop("args", None) + kwargs.pop("add_labels", None) + + _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) + size_norm = kwargs.pop("size_norm", None) + size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid + cmap_params = kwargs.pop("cmap_params", None) + + figsize = kwargs.pop("figsize", None) + subplot_kws = dict() + if z is not None and ax is None: + # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa + + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) + + add_guide = kwargs.pop("add_guide", None) + if add_legend is not None: + pass + elif add_guide is None or add_guide is True: + add_legend = True if _data["hue_style"] == "discrete" else False + elif add_legend is None: + add_legend = False + + if add_colorbar is not None: + pass + elif add_guide is None or add_guide is True: + add_colorbar = True if _data["hue_style"] == "continuous" else False + else: + add_colorbar = False + + # need to infer size_mapping with full dataset + _data.update( + _infer_scatter_data( + darray, + x, + z, + hue, + _sizes, + size_norm, + size_mapping, + _MARKERSIZE_RANGE, + ) + ) + + cmap_params_subset = {} + if _data["hue"] is not None: + kwargs.update(c=_data["hue"].values.ravel()) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + scatter, _data["hue"].values, **locals() + ) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } + + if _data["size"] is not None: + kwargs.update(s=_data["size"].values.ravel()) + + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + # Plot the data. 3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + + primitive = ax.scatter( + *[ + _data[v].values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], + **cmap_params_subset, + **kwargs, + ) + + # Set x, y, z labels: + i = 0 + set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] + for v in axis_order: + if _data.get(f"{v}label", None) is not None: + set_label[i](_data[f"{v}label"]) + i += 1 + + if add_legend: + + def to_label(data, key, x): + """Map prop values back to its original values.""" + if key in data: + # Use reindex to be less sensitive to float errors. + # Return as numpy array since legend_elements + # seems to require that: + return data[key].reindex(x, method="nearest").to_numpy() + else: + return x + + handles, labels = [], [] + for subtitle, prop, func in [ + ( + _data["hue_label"], + "colors", + functools.partial(to_label, _data, "hue_to_label"), + ), + ( + _data["size_label"], + "sizes", + functools.partial(to_label, _data, "size_to_label"), + ), + ]: + if subtitle: + # Get legend handles and labels that displays the + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: + hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) + hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) + handles += hdl + labels += lbl + legend = ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) + + if add_colorbar and _data["hue_label"]: + if _data["hue_style"] == "discrete": + raise NotImplementedError("Cannot create a colorbar for non numerics.") + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = _data["hue_label"] + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + return primitive + + # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. class _PlotMethods: @@ -488,7 +909,7 @@ def __call__(self, **kwargs): # we can't use functools.wraps here since that also modifies the name / qualname __doc__ = __call__.__doc__ = plot.__doc__ - __call__.__wrapped__ = plot # type: ignore + __call__.__wrapped__ = plot # type: ignore[attr-defined] __call__.__annotations__ = plot.__annotations__ @functools.wraps(hist) @@ -503,6 +924,10 @@ def line(self, *args, **kwargs): def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) + @functools.wraps(scatter) + def _scatter(self, *args, **kwargs): + return scatter(self._da, *args, **kwargs) + def override_signature(f): def wrapper(func): @@ -523,100 +948,108 @@ def _plot2d(plotfunc): Parameters ---------- darray : DataArray - Must be 2 dimensional, unless creating faceted plots - x : string, optional - Coordinate for x axis. If None use darray.dims[1] - y : string, optional - Coordinate for y axis. If None use darray.dims[0] + Must be two-dimensional, unless creating faceted plots. + x : str, optional + Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``. + y : str, optional + Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``. figsize : tuple, optional A tuple (width, height) of the figure in inches. Mutually exclusive with ``size`` and ``ax``. aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in inches. Only used if a ``size`` is provided. size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. ax : matplotlib axes object, optional - Axis on which to plot this figure. By default, use the current axis. + Axes on which to plot. By default, use the current axes. Mutually exclusive with ``size`` and ``figsize``. row : string, optional - If passed, make row faceted plots on this dimension name + If passed, make row faceted plots on this dimension name. col : string, optional - If passed, make column faceted plots on this dimension name + If passed, make column faceted plots on this dimension name. col_wrap : int, optional - Use together with ``col`` to wrap faceted plots - xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional - Specifies scaling for the x- and y-axes respectively - xticks, yticks : Specify tick locations for x- and y-axes - xlim, ylim : Specify x- and y-axes limits + Use together with ``col`` to wrap faceted plots. + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : array-like, optional + Specify *x*- and *y*-axis limits. xincrease : None, True, or False, optional - Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function. + Should the values on the *x* axis be increasing from left to right? + If ``None``, use the default for the Matplotlib function. yincrease : None, True, or False, optional - Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function. + Should the values on the *y* axis be increasing from top to bottom? + If ``None``, use the default for the Matplotlib function. add_colorbar : bool, optional - Adds colorbar to axis + Add colorbar to axes. add_labels : bool, optional - Use xarray metadata to label axes - norm : ``matplotlib.colors.Normalize`` instance, optional - If the ``norm`` has vmin or vmax specified, the corresponding kwarg - must be None. - vmin, vmax : floats, optional + Use xarray metadata to label axes. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + vmin, vmax : float, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, setting one of these values will fix the other by symmetry around ``center``. Setting both values prevents use of a diverging colormap. If discrete levels are provided as an explicit list, both of these values are ignored. - cmap : matplotlib colormap name or object, optional + cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. If not provided, this - will be either be ``viridis`` (if the function infers a sequential - dataset) or ``RdBu_r`` (if the function infers a diverging dataset). - When `Seaborn` is installed, ``cmap`` may also be a `seaborn` - color palette. If ``cmap`` is seaborn color palette and the plot type - is not ``contour`` or ``contourf``, ``levels`` must also be specified. - colors : discrete colors to plot, optional - A single color or a list of colors. If the plot type is not ``contour`` - or ``contourf``, the ``levels`` argument is required. + will be either be ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette and the plot type + is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. + colors : str or array-like of color-like, optional + A single color or a sequence of colors. If the plot type is not ``'contour'`` + or ``'contourf'``, the ``levels`` argument is required. center : float, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a diverging colormap. robust : bool, optional - If True and ``vmin`` or ``vmax`` are absent, the colormap range is + If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is computed with 2nd and 98th percentiles instead of the extreme values. - extend : {"neither", "both", "min", "max"}, optional + extend : {'neither', 'both', 'min', 'max'}, optional How to draw arrows extending the colorbar beyond its limits. If not - provided, extend is inferred from vmin, vmax and the data limits. - levels : int or list-like object, optional - Split the colormap (cmap) into discrete color intervals. If an integer + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer is provided, "nice" levels are chosen based on the data range: this can imply that the final number of levels is not exactly the expected one. Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to setting ``levels=np.linspace(vmin, vmax, N)``. infer_intervals : bool, optional - Only applies to pcolormesh. If True, the coordinate intervals are - passed to pcolormesh. If False, the original coordinates are used + Only applies to pcolormesh. If ``True``, the coordinate intervals are + passed to pcolormesh. If ``False``, the original coordinates are used (this can be useful for certain map projections). The default is to always infer intervals, unless the mesh is irregular and plotted on a map projection. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only used - for 2D and FacetGrid plots. - cbar_ax : matplotlib Axes, optional + Dictionary of keyword arguments for Matplotlib subplots. Only used + for 2D and faceted plots. + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + cbar_ax : matplotlib axes object, optional Axes in which to draw the colorbar. cbar_kwargs : dict, optional - Dictionary of keyword arguments to pass to the colorbar. + Dictionary of keyword arguments to pass to the colorbar + (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). **kwargs : optional - Additional arguments to wrapped matplotlib function + Additional keyword arguments to wrapped Matplotlib function. Returns ------- artist : - The same type of primitive artist that the wrapped matplotlib - function returns + The same type of primitive artist that the wrapped Matplotlib + function returns. """ # Build on the original docstring @@ -674,7 +1107,11 @@ def newplotfunc( # Decide on a default for the colorbar before facetgrids if add_colorbar is None: - add_colorbar = plotfunc.__name__ != "contour" + add_colorbar = True + if plotfunc.__name__ == "contour" or ( + plotfunc.__name__ == "surface" and cmap is None + ): + add_colorbar = False imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 3 + (row is not None) + (col is not None) ) @@ -687,6 +1124,25 @@ def newplotfunc( darray = _rescale_imshow_rgb(darray, vmin, vmax, robust) vmin, vmax, robust = None, None, False + if subplot_kws is None: + subplot_kws = dict() + + if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): + if ax is None: + # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 + + # delete so it does not end up in locals() + del Axes3D + + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + + # In facet grids, shared axis labels don't make sense for surface plots + sharex = False + sharey = False + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -705,6 +1161,19 @@ def newplotfunc( plt = import_matplotlib_pyplot() + if ( + plotfunc.__name__ == "surface" + and not kwargs.get("_is_facetgrid", False) + and ax is not None + ): + import mpl_toolkits # type: ignore + + if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): + raise ValueError( + "If ax is passed to surface(), it must be created with " + 'projection="3d"' + ) + rgb = kwargs.pop("rgb", None) if rgb is not None and plotfunc.__name__ != "imshow": raise ValueError('The "rgb" keyword is only valid for imshow()') @@ -718,28 +1187,22 @@ def newplotfunc( darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb ) - # better to pass the ndarrays directly to plotting functions - xval = darray[xlab].values - yval = darray[ylab].values - - # check if we need to broadcast one dimension - if xval.ndim < yval.ndim: - dims = darray[ylab].dims - if xval.shape[0] == yval.shape[0]: - xval = np.broadcast_to(xval[:, np.newaxis], yval.shape) - else: - xval = np.broadcast_to(xval[np.newaxis, :], yval.shape) + xval = darray[xlab] + yval = darray[ylab] - elif yval.ndim < xval.ndim: - dims = darray[xlab].dims - if yval.shape[0] == xval.shape[0]: - yval = np.broadcast_to(yval[:, np.newaxis], xval.shape) - else: - yval = np.broadcast_to(yval[np.newaxis, :], xval.shape) - elif xval.ndim == 2: - dims = darray[xlab].dims + if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": + # Passing 2d coordinate values, need to ensure they are transposed the same + # way as darray. + # Also surface plots always need 2d coordinates + xval = xval.broadcast_like(darray) + yval = yval.broadcast_like(darray) + dims = darray.dims else: - dims = (darray[ylab].dims[0], darray[xlab].dims[0]) + dims = (yval.dims[0], xval.dims[0]) + + # better to pass the ndarrays directly to plotting functions + xval = xval.to_numpy() + yval = yval.to_numpy() # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names @@ -783,13 +1246,13 @@ def newplotfunc( if "pcolormesh" == plotfunc.__name__: kwargs["infer_intervals"] = infer_intervals + kwargs["xscale"] = xscale + kwargs["yscale"] = yscale if "imshow" == plotfunc.__name__ and isinstance(aspect, str): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") - if subplot_kws is None: - subplot_kws = dict() ax = get_axis(figsize, size, aspect, ax, **subplot_kws) primitive = plotfunc( @@ -809,6 +1272,8 @@ def newplotfunc( ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) + if plotfunc.__name__ == "surface": + ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: if add_labels and "label" not in cbar_kwargs: @@ -899,26 +1364,28 @@ def plotmethod( @_plot2d def imshow(x, y, z, ax, **kwargs): """ - Image plot of 2d DataArray using matplotlib.pyplot + Image plot of 2D DataArray. - Wraps :func:`matplotlib:matplotlib.pyplot.imshow` + Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`. While other plot methods require the DataArray to be strictly two-dimensional, ``imshow`` also accepts a 3D array where some dimension can be interpreted as RGB or RGBA color channels and allows this dimension to be specified via the kwarg ``rgb=``. - Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA - data, by applying a single scaling factor and offset to all bands. + Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax`` + for RGB(A) data, + xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data + by applying a single scaling factor and offset to all bands. Passing ``robust=True`` infers ``vmin`` and ``vmax`` :ref:`in the usual way `. .. note:: This function needs uniformly spaced coordinates to - properly label the axes. Call DataArray.plot() to check. + properly label the axes. Call :py:meth:`DataArray.plot` to check. - The pixels are centered on the coordinates values. Ie, if the coordinate - value is 3.2 then the pixels for those coordinates will be centered on 3.2. + The pixels are centered on the coordinates. For example, if the coordinate + value is 3.2, then the pixels for those coordinates will be centered on 3.2. """ if x.ndim != 1 or y.ndim != 1: @@ -926,18 +1393,26 @@ def imshow(x, y, z, ax, **kwargs): "imshow requires 1D coordinates, try using pcolormesh or contour(f)" ) - # Centering the pixels- Assumes uniform spacing - try: - xstep = (x[1] - x[0]) / 2.0 - except IndexError: - # Arbitrary default value, similar to matplotlib behaviour - xstep = 0.1 - try: - ystep = (y[1] - y[0]) / 2.0 - except IndexError: - ystep = 0.1 - left, right = x[0] - xstep, x[-1] + xstep - bottom, top = y[-1] + ystep, y[0] - ystep + def _center_pixels(x): + """Center the pixels on the coordinates.""" + if np.issubdtype(x.dtype, str): + # When using strings as inputs imshow converts it to + # integers. Choose extent values which puts the indices in + # in the center of the pixels: + return 0 - 0.5, len(x) - 0.5 + + try: + # Center the pixels assuming uniform spacing: + xstep = 0.5 * (x[1] - x[0]) + except IndexError: + # Arbitrary default value, similar to matplotlib behaviour: + xstep = 0.1 + + return x[0] - xstep, x[-1] + xstep + + # Center the pixels: + left, right = _center_pixels(x) + top, bottom = _center_pixels(y) defaults = {"origin": "upper", "interpolation": "nearest"} @@ -968,15 +1443,22 @@ def imshow(x, y, z, ax, **kwargs): primitive = ax.imshow(z, **defaults) + # If x or y are strings the ticklabels have been replaced with + # integer indices. Replace them back to strings: + for axis, v in [("x", x), ("y", y)]: + if np.issubdtype(v.dtype, str): + getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) + getattr(ax, f"set_{axis}ticklabels")(v) + return primitive @_plot2d def contour(x, y, z, ax, **kwargs): """ - Contour plot of 2d DataArray + Contour plot of 2D DataArray. - Wraps :func:`matplotlib:matplotlib.pyplot.contour` + Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`. """ primitive = ax.contour(x, y, z, **kwargs) return primitive @@ -985,20 +1467,20 @@ def contour(x, y, z, ax, **kwargs): @_plot2d def contourf(x, y, z, ax, **kwargs): """ - Filled contour plot of 2d DataArray + Filled contour plot of 2D DataArray. - Wraps :func:`matplotlib:matplotlib.pyplot.contourf` + Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`. """ primitive = ax.contourf(x, y, z, **kwargs) return primitive @_plot2d -def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): +def pcolormesh(x, y, z, ax, xscale=None, yscale=None, infer_intervals=None, **kwargs): """ - Pseudocolor plot of 2d DataArray + Pseudocolor plot of 2D DataArray. - Wraps :func:`matplotlib:matplotlib.pyplot.pcolormesh` + Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`. """ # decide on a default for infer_intervals (GH781) @@ -1012,24 +1494,32 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): else: infer_intervals = True - if infer_intervals and ( - (np.shape(x)[0] == np.shape(z)[1]) - or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + if ( + infer_intervals + and not np.issubdtype(x.dtype, str) + and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + ) ): if len(x.shape) == 1: - x = _infer_interval_breaks(x, check_monotonic=True) + x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) else: # we have to infer the intervals on both axes - x = _infer_interval_breaks(x, axis=1) - x = _infer_interval_breaks(x, axis=0) + x = _infer_interval_breaks(x, axis=1, scale=xscale) + x = _infer_interval_breaks(x, axis=0, scale=xscale) - if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]): + if ( + infer_intervals + and not np.issubdtype(y.dtype, str) + and (np.shape(y)[0] == np.shape(z)[0]) + ): if len(y.shape) == 1: - y = _infer_interval_breaks(y, check_monotonic=True) + y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) else: # we have to infer the intervals on both axes - y = _infer_interval_breaks(y, axis=1) - y = _infer_interval_breaks(y, axis=0) + y = _infer_interval_breaks(y, axis=1, scale=yscale) + y = _infer_interval_breaks(y, axis=0, scale=yscale) primitive = ax.pcolormesh(x, y, z, **kwargs) @@ -1041,3 +1531,14 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): ax.set_ylim(y[0], y[-1]) return primitive + + +@_plot2d +def surface(x, y, z, ax, **kwargs): + """ + Surface plot of 2D DataArray. + + Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. + """ + primitive = ax.plot_surface(x, y, z, **kwargs) + return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 601b23a3065..f2f296096a5 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -9,6 +9,7 @@ import pandas as pd from ..core.options import OPTIONS +from ..core.pycompat import DuckArrayModule from ..core.utils import is_scalar try: @@ -44,14 +45,13 @@ def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax if extend_min and extend_max: - extend = "both" + return "both" elif extend_min: - extend = "min" + return "min" elif extend_max: - extend = "max" + return "max" else: - extend = "neither" - return extend + return "neither" def _build_discrete_cmap(cmap, levels, extend, filled): @@ -60,6 +60,9 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ import matplotlib as mpl + if len(levels) == 1: + levels = [levels[0], levels[0]] + if not filled: # non-filled contour plots extend = "max" @@ -159,12 +162,12 @@ def _determine_cmap_params( Use some heuristics to set good defaults for colorbar and range. Parameters - ========== - plot_data: Numpy array + ---------- + plot_data : Numpy array Doesn't handle xarray objects Returns - ======= + ------- cmap_params : dict Use depends on the type of the plotting function """ @@ -317,7 +320,7 @@ def _infer_xy_labels_3d(darray, x, y, rgb): if len(set(not_none)) < len(not_none): raise ValueError( "Dimension names must be None or unique strings, but imshow was " - "passed x=%r, y=%r, and rgb=%r." % (x, y, rgb) + f"passed x={x!r}, y={y!r}, and rgb={rgb!r}." ) for label in not_none: if label not in darray.dims: @@ -339,8 +342,7 @@ def _infer_xy_labels_3d(darray, x, y, rgb): rgb = could_be_color[0] if rgb is not None and darray[rgb].size not in (3, 4): raise ValueError( - "Cannot interpret dim %r of size %s as RGB or RGBA." - % (rgb, darray[rgb].size) + f"Cannot interpret dim {rgb!r} of size {darray[rgb].size} as RGB or RGBA." ) # If rgb dimension is still unknown, there must be two or three dimensions @@ -350,9 +352,9 @@ def _infer_xy_labels_3d(darray, x, y, rgb): rgb = could_be_color[-1] warnings.warn( "Several dimensions of this array could be colors. Xarray " - "will use the last possible dimension (%r) to match " + f"will use the last possible dimension ({rgb!r}) to match " "matplotlib.pyplot.imshow. You can pass names of x, y, " - "and/or rgb dimensions to override this guess." % rgb + "and/or rgb dimensions to override this guess." ) assert rgb is not None @@ -440,11 +442,26 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): raise ValueError("cannot use subplot_kws with existing ax") if ax is None: - ax = plt.gca(**kwargs) + ax = _maybe_gca(**kwargs) return ax +def _maybe_gca(**kwargs): + + import matplotlib.pyplot as plt + + # can call gcf unconditionally: either it exists or would be created by plt.axes + f = plt.gcf() + + # only call gca if an active axes exists + if f.axes: + # can not pass kwargs to active axes + return plt.gca() + + return plt.axes(**kwargs) + + def label_from_attrs(da, extra=""): """Makes informative labels if variable metadata (attrs) follows CF conventions.""" @@ -458,12 +475,20 @@ def label_from_attrs(da, extra=""): else: name = "" - if da.attrs.get("units"): - units = " [{}]".format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = " [{}]".format(da.attrs["unit"]) + def _get_units_from_attrs(da): + if da.attrs.get("units"): + units = " [{}]".format(da.attrs["units"]) + elif da.attrs.get("unit"): + units = " [{}]".format(da.attrs["unit"]) + else: + units = "" + return units + + pint_array_type = DuckArrayModule("pint").type + if isinstance(da.data, pint_array_type): + units = " [{}]".format(str(da.data.units)) else: - units = "" + units = _get_units_from_attrs(da) return "\n".join(textwrap.wrap(name + extra + units, 30)) @@ -588,7 +613,14 @@ def _ensure_plottable(*args): Raise exception if there is anything in args that can't be plotted on an axis by matplotlib. """ - numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_] + numpy_types = [ + np.floating, + np.integer, + np.timedelta64, + np.datetime64, + np.bool_, + np.str_, + ] other_types = [datetime] try: import cftime @@ -659,15 +691,15 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1 if vmax < vmin: raise ValueError( - "vmin=%r is less than the default vmax (%r) - you must supply " - "a vmax > vmin in this case." % (vmin, vmax) + f"vmin={vmin!r} is less than the default vmax ({vmax!r}) - you must supply " + "a vmax > vmin in this case." ) elif vmin is None: vmin = 0 if vmin > vmax: raise ValueError( - "vmax=%r is less than the default vmin (0) - you must supply " - "a vmin < vmax in this case." % vmax + f"vmax={vmax!r} is less than the default vmin (0) - you must supply " + "a vmin < vmax in this case." ) # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float # to avoid precision loss, integer over/underflow, etc with extreme inputs. @@ -748,13 +780,16 @@ def _is_monotonic(coord, axis=0): return np.all(delta_pos) or np.all(delta_neg) -def _infer_interval_breaks(coord, axis=0, check_monotonic=False): +def _infer_interval_breaks(coord, axis=0, scale=None, check_monotonic=False): """ >>> _infer_interval_breaks(np.arange(5)) array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]) >>> _infer_interval_breaks([[0, 1], [3, 4]], axis=1) array([[-0.5, 0.5, 1.5], [ 2.5, 3.5, 4.5]]) + >>> _infer_interval_breaks(np.logspace(-2, 2, 5), scale="log") + array([3.16227766e-03, 3.16227766e-02, 3.16227766e-01, 3.16227766e+00, + 3.16227766e+01, 3.16227766e+02]) """ coord = np.asarray(coord) @@ -768,6 +803,15 @@ def _infer_interval_breaks(coord, axis=0, check_monotonic=False): "the `seaborn` statistical plotting library." % axis ) + # If logscale, compute the intervals in the logarithmic space + if scale == "log": + if (coord <= 0).any(): + raise ValueError( + "Found negative or zero value in coordinates. " + + "Coordinates must be positive on logscale plots." + ) + coord = np.log10(coord) + deltas = 0.5 * np.diff(coord, axis=axis) if deltas.size == 0: deltas = np.array(0.0) @@ -776,7 +820,13 @@ def _infer_interval_breaks(coord, axis=0, check_monotonic=False): trim_last = tuple( slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim) ) - return np.concatenate([first, coord[trim_last] + deltas, last], axis=axis) + interval_breaks = np.concatenate( + [first, coord[trim_last] + deltas, last], axis=axis + ) + if scale == "log": + # Recovert the intervals into the linear space + return np.power(10, interval_breaks) + return interval_breaks def _process_cmap_cbar_kwargs( @@ -791,17 +841,24 @@ def _process_cmap_cbar_kwargs( ): """ Parameters - ========== + ---------- func : plotting function data : ndarray, Data values Returns - ======= + ------- cmap_params - cbar_kwargs """ + if func.__name__ == "surface": + # Leave user to specify cmap settings for surface plots + kwargs["cmap"] = cmap + return { + k: kwargs.get(k, None) + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + }, {} + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) if "contour" in func.__name__ and levels is None: @@ -842,3 +899,240 @@ def _process_cmap_cbar_kwargs( } return cmap_params, cbar_kwargs + + +def _get_nice_quiver_magnitude(u, v): + import matplotlib as mpl + + ticker = mpl.ticker.MaxNLocator(3) + mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) + magnitude = ticker.tick_values(0, mean)[-2] + return magnitude + + +# Copied from matplotlib, tweaked so func can return strings. +# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import warnings + + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + + if prop == "colors": + arr = self.get_array() + if arr is None: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + + elif prop == "sizes": + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." + ) + + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) + + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if num == "auto": + num = 9 + if len(values) <= num: + num = None + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.set_bounds(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max + ) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) == int: + loc = mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + + # Default settings for handles, add or override with kwargs: + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + color, size = _get_color_and_size(val) + h = mlines.Line2D( + [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw + ) + handles.append(h) + labels.append(fmt(lab)) + + return handles, labels + + +def _legend_add_subtitle(handles, labels, text, func): + """Add a subtitle to legend handles.""" + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = func([], [], label=text) + blank_handle.set_visible(False) + + # Subtitles are shown first: + handles = [blank_handle] + handles + labels = [text] + labels + + return handles, labels + + +def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + for hpack in hpackers: + draw_area, text_area = hpack.get_children() + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the same font size + # as normal legend titles: + text.set_size(font_size) diff --git a/xarray/testing.py b/xarray/testing.py index ca72a4bee8e..40ca12852b9 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,14 +1,14 @@ """Testing functions exposed to the user API""" import functools +import warnings from typing import Hashable, Set, Union import numpy as np -import pandas as pd from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset -from xarray.core.indexes import default_indexes +from xarray.core.indexes import Index, default_indexes from xarray.core.variable import IndexVariable, Variable __all__ = ( @@ -21,6 +21,21 @@ ) +def ensure_warnings(func): + # sometimes tests elevate warnings to errors + # -> make sure that does not happen in the assert_* functions + @functools.wraps(func) + def wrapper(*args, **kwargs): + __tracebackhide__ = True + + with warnings.catch_warnings(): + warnings.simplefilter("always") + + return func(*args, **kwargs) + + return wrapper + + def _decode_string_data(data): if data.dtype.kind == "S": return np.core.defchararray.decode(data, "utf-8", "replace") @@ -38,6 +53,7 @@ def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=Tru return duck_array_ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) +@ensure_warnings def assert_equal(a, b): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -54,9 +70,9 @@ def assert_equal(a, b): b : xarray.Dataset, xarray.DataArray or xarray.Variable The second object to compare. - See also + See Also -------- - assert_identical, assert_allclose, Dataset.equals, DataArray.equals, + assert_identical, assert_allclose, Dataset.equals, DataArray.equals numpy.testing.assert_array_equal """ __tracebackhide__ = True @@ -69,6 +85,7 @@ def assert_equal(a, b): raise TypeError("{} not supported by assertion comparison".format(type(a))) +@ensure_warnings def assert_identical(a, b): """Like :py:func:`xarray.testing.assert_equal`, but also matches the objects' names and attributes. @@ -82,7 +99,7 @@ def assert_identical(a, b): b : xarray.Dataset, xarray.DataArray or xarray.Variable The second object to compare. - See also + See Also -------- assert_equal, assert_allclose, Dataset.equals, DataArray.equals """ @@ -99,6 +116,7 @@ def assert_identical(a, b): raise TypeError("{} not supported by assertion comparison".format(type(a))) +@ensure_warnings def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects. @@ -120,7 +138,7 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): This is useful for testing serialization methods on Python 3 that return saved strings as bytes. - See also + See Also -------- assert_identical, assert_equal, numpy.testing.assert_allclose """ @@ -182,18 +200,20 @@ def _format_message(x, y, err_msg, verbose): return "\n".join(parts) +@ensure_warnings def assert_duckarray_allclose( actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True ): - """ Like `np.testing.assert_allclose`, but for duckarrays. """ + """Like `np.testing.assert_allclose`, but for duckarrays.""" __tracebackhide__ = True allclose = duck_array_ops.allclose_or_equiv(actual, desired, rtol=rtol, atol=atol) assert allclose, _format_message(actual, desired, err_msg=err_msg, verbose=verbose) +@ensure_warnings def assert_duckarray_equal(x, y, err_msg="", verbose=True): - """ Like `np.testing.assert_array_equal`, but for duckarrays """ + """Like `np.testing.assert_array_equal`, but for duckarrays""" __tracebackhide__ = True if not utils.is_duck_array(x) and not utils.is_scalar(x): @@ -233,7 +253,7 @@ def assert_chunks_equal(a, b): def _assert_indexes_invariants_checks(indexes, possible_coord_variables, dims): assert isinstance(indexes, dict), indexes - assert all(isinstance(v, pd.Index) for v in indexes.values()), { + assert all(isinstance(v, Index) for v in indexes.values()), { k: type(v) for k, v in indexes.items() } diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7c18f1a8c8a..d757fb451cc 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -1,17 +1,18 @@ import importlib import platform -import re import warnings from contextlib import contextmanager from distutils import version from unittest import mock # noqa: F401 import numpy as np +import pandas as pd import pytest from numpy.testing import assert_array_equal # noqa: F401 from pandas.testing import assert_frame_equal # noqa: F401 import xarray.testing +from xarray import Dataset from xarray.core import utils from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.indexing import ExplicitlyIndexed @@ -60,6 +61,9 @@ def LooseVersion(vstring): has_matplotlib, requires_matplotlib = _importorskip("matplotlib") +has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip( + "matplotlib", minversion="3.3.0" +) has_scipy, requires_scipy = _importorskip("scipy") has_pydap, requires_pydap = _importorskip("pydap.client") has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") @@ -67,20 +71,23 @@ def LooseVersion(vstring): has_pynio, requires_pynio = _importorskip("Nio") has_pseudonetcdf, requires_pseudonetcdf = _importorskip("PseudoNetCDF") has_cftime, requires_cftime = _importorskip("cftime") -has_cftime_1_1_0, requires_cftime_1_1_0 = _importorskip("cftime", minversion="1.1.0.0") +has_cftime_1_4_1, requires_cftime_1_4_1 = _importorskip("cftime", minversion="1.4.1") has_dask, requires_dask = _importorskip("dask") has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_nc_time_axis, requires_nc_time_axis = _importorskip("nc_time_axis") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") +has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_cfgrib, requires_cfgrib = _importorskip("cfgrib") has_numbagg, requires_numbagg = _importorskip("numbagg") has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") +has_cupy, requires_cupy = _importorskip("cupy") has_cartopy, requires_cartopy = _importorskip("cartopy") # Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15") +has_numexpr, requires_numexpr = _importorskip("numexpr") # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 @@ -133,18 +140,6 @@ def raise_if_dask_computes(max_computes=0): network = pytest.mark.network -@contextmanager -def raises_regex(error, pattern): - __tracebackhide__ = True - with pytest.raises(error) as excinfo: - yield - message = str(excinfo.value) - if not re.search(pattern, message): - raise AssertionError( - f"exception {excinfo.value!r} did not match pattern {pattern!r}" - ) - - class UnexpectedDataAccess(Exception): pass @@ -208,3 +203,42 @@ def assert_allclose(a, b, **kwargs): xarray.testing.assert_allclose(a, b, **kwargs) xarray.testing._assert_internal_invariants(a) xarray.testing._assert_internal_invariants(b) + + +def create_test_data(seed=None, add_attrs=True): + rs = np.random.RandomState(seed) + _vars = { + "var1": ["dim1", "dim2"], + "var2": ["dim1", "dim2"], + "var3": ["dim3", "dim1"], + } + _dims = {"dim1": 8, "dim2": 9, "dim3": 10} + + obj = Dataset() + obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) + obj["dim3"] = ("dim3", list("abcdefghij")) + obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) + for v, dims in sorted(_vars.items()): + data = rs.normal(size=tuple(_dims[d] for d in dims)) + obj[v] = (dims, data) + if add_attrs: + obj[v].attrs = {"foo": "variable"} + obj.coords["numbers"] = ( + "dim3", + np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), + ) + obj.encoding = {"foo": "bar"} + assert all(obj.data.flags.writeable for obj in obj.variables.values()) + return obj + + +_CFTIME_CALENDARS = [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", + "standard", +] diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py new file mode 100644 index 00000000000..7988b4a7b19 --- /dev/null +++ b/xarray/tests/conftest.py @@ -0,0 +1,8 @@ +import pytest + +from . import requires_dask + + +@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)]) +def backend(request): + return request.param diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 984bfc763bc..62da3bab2cd 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -12,7 +12,6 @@ assert_equal, assert_identical, raise_if_dask_computes, - raises_regex, requires_cftime, requires_dask, ) @@ -59,6 +58,8 @@ def setup(self): "weekday", "dayofyear", "quarter", + "date", + "time", "is_month_start", "is_month_end", "is_quarter_start", @@ -98,8 +99,8 @@ def test_field_access(self, field): def test_isocalendar(self, field, pandas_field): if LooseVersion(pd.__version__) < "1.1.0": - with raises_regex( - AttributeError, "'isocalendar' not available in pandas < 1.1.0" + with pytest.raises( + AttributeError, match=r"'isocalendar' not available in pandas < 1.1.0" ): self.data.time.dt.isocalendar()[field] return @@ -122,7 +123,7 @@ def test_not_datetime_type(self): nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) - with raises_regex(TypeError, "dt"): + with pytest.raises(TypeError, match=r"dt"): nontime_data.time.dt @pytest.mark.filterwarnings("ignore:dt.weekofyear and dt.week have been deprecated") @@ -144,6 +145,8 @@ def test_not_datetime_type(self): "weekday", "dayofyear", "quarter", + "date", + "time", "is_month_start", "is_month_end", "is_quarter_start", @@ -183,8 +186,8 @@ def test_isocalendar_dask(self, field): import dask.array as da if LooseVersion(pd.__version__) < "1.1.0": - with raises_regex( - AttributeError, "'isocalendar' not available in pandas < 1.1.0" + with pytest.raises( + AttributeError, match=r"'isocalendar' not available in pandas < 1.1.0" ): self.data.time.dt.isocalendar()[field] return @@ -289,7 +292,7 @@ def test_not_datetime_type(self): nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) - with raises_regex(TypeError, "dt"): + with pytest.raises(TypeError, match=r"dt"): nontime_data.time.dt @pytest.mark.parametrize( @@ -424,16 +427,26 @@ def test_field_access(data, field): @requires_cftime def test_isocalendar_cftime(data): - with raises_regex( - AttributeError, "'CFTimeIndex' object has no attribute 'isocalendar'" + with pytest.raises( + AttributeError, match=r"'CFTimeIndex' object has no attribute 'isocalendar'" ): data.time.dt.isocalendar() +@requires_cftime +def test_date_cftime(data): + + with pytest.raises( + AttributeError, + match=r"'CFTimeIndex' object has no attribute `date`. Consider using the floor method instead, for instance: `.time.dt.floor\('D'\)`.", + ): + data.time.dt.date() + + @requires_cftime @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_cftime_strftime_access(data): - """ compare cftime formatting against datetime formatting """ + """compare cftime formatting against datetime formatting""" date_format = "%Y%m%d%H" result = data.time.dt.strftime(date_format) datetime_array = xr.DataArray( diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index e0cbdb7377a..519ca762c41 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -44,7 +44,7 @@ import xarray as xr -from . import assert_equal, requires_dask +from . import assert_equal, assert_identical, requires_dask @pytest.fixture(params=[np.str_, np.bytes_]) @@ -61,97 +61,363 @@ def test_dask(): result = xarr.str.len().compute() expected = xr.DataArray([1, 1, 1]) + assert result.dtype == expected.dtype assert_equal(result, expected) def test_count(dtype): values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) - result = values.str.count("f[o]+") + pat_str = dtype(r"f[o]+") + pat_re = re.compile(pat_str) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + expected = xr.DataArray([1, 2, 4]) - assert_equal(result, expected) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_count_broadcast(dtype): + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) + pat_re = np.array([re.compile(x) for x in pat_str]) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 4, 3]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) def test_contains(dtype): values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) + # case insensitive using regex - result = values.str.contains("FOO|mmm", case=False) + pat = values.dtype.type("FOO|mmm") + result = values.str.contains(pat, case=False) expected = xr.DataArray([True, False, True, True]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + pat = values.dtype.type("Foo|mMm") + result = values.str.contains(pat) + expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat)) + assert result.dtype == expected.dtype assert_equal(result, expected) + # case insensitive without regex result = values.str.contains("foo", regex=False, case=False) expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive without regex + result = values.str.contains("fO", regex=False, case=True) + expected = xr.DataArray([False, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # regex regex=False + pat_re = re.compile("(/w+)") + with pytest.raises( + ValueError, + match="Must use regular expression matching for regular expression object.", + ): + values.str.contains(pat_re, regex=False) + + +def test_contains_broadcast(dtype): + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dims="X").astype( + dtype + ) + pat_str = xr.DataArray(["FOO|mmm", "Foo", "MMM"], dims="Y").astype(dtype) + pat_re = xr.DataArray([re.compile(x) for x in pat_str.data], dims="Y") + + # case insensitive using regex + result = values.str.contains(pat_str, case=False) + expected = xr.DataArray( + [ + [True, True, False], + [False, False, False], + [True, True, True], + [True, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + result = values.str.contains(pat_str) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(pat_re) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive without regex + result = values.str.contains(pat_str, regex=False, case=False) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, True, True], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive with regex + result = values.str.contains(pat_str, regex=False, case=True) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype assert_equal(result, expected) def test_starts_ends_with(dtype): values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) + result = values.str.startswith("foo") expected = xr.DataArray([False, True, False, False, True]) + assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.endswith("foo") expected = xr.DataArray([False, False, False, True, True]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_starts_ends_with_broadcast(dtype): + values = xr.DataArray( + ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" + ).astype(dtype) + pat = xr.DataArray(["foo", "bar"], dims="Y").astype(dtype) + + result = values.str.startswith(pat) + expected = xr.DataArray( + [[False, False], [True, False], [False, False], [False, True], [True, False]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith(pat) + expected = xr.DataArray( + [[False, False], [False, False], [False, False], [True, False], [False, True]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype assert_equal(result, expected) -def test_case(dtype): - da = xr.DataArray(["SOme word"]).astype(dtype) - capitalized = xr.DataArray(["Some word"]).astype(dtype) - lowered = xr.DataArray(["some word"]).astype(dtype) - swapped = xr.DataArray(["soME WORD"]).astype(dtype) - titled = xr.DataArray(["Some Word"]).astype(dtype) - uppered = xr.DataArray(["SOME WORD"]).astype(dtype) - assert_equal(da.str.capitalize(), capitalized) - assert_equal(da.str.lower(), lowered) - assert_equal(da.str.swapcase(), swapped) - assert_equal(da.str.title(), titled) - assert_equal(da.str.upper(), uppered) +def test_case_bytes(): + value = xr.DataArray(["SOme wOrd"]).astype(np.bytes_) + + exp_capitalized = xr.DataArray(["Some word"]).astype(np.bytes_) + exp_lowered = xr.DataArray(["some word"]).astype(np.bytes_) + exp_swapped = xr.DataArray(["soME WoRD"]).astype(np.bytes_) + exp_titled = xr.DataArray(["Some Word"]).astype(np.bytes_) + exp_uppered = xr.DataArray(["SOME WORD"]).astype(np.bytes_) + + res_capitalized = value.str.capitalize() + res_lowered = value.str.lower() + res_swapped = value.str.swapcase() + res_titled = value.str.title() + res_uppered = value.str.upper() + + assert res_capitalized.dtype == exp_capitalized.dtype + assert res_lowered.dtype == exp_lowered.dtype + assert res_swapped.dtype == exp_swapped.dtype + assert res_titled.dtype == exp_titled.dtype + assert res_uppered.dtype == exp_uppered.dtype + + assert_equal(res_capitalized, exp_capitalized) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) + + +def test_case_str(): + # This string includes some unicode characters + # that are common case management corner cases + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.unicode_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype( + np.unicode_ + ) + + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.unicode_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype( + np.unicode_ + ) + + res_capitalized = value.str.capitalize() + res_casefolded = value.str.casefold() + res_lowered = value.str.lower() + res_swapped = value.str.swapcase() + res_titled = value.str.title() + res_uppered = value.str.upper() + + res_norm_nfc = value.str.normalize("NFC") + res_norm_nfd = value.str.normalize("NFD") + res_norm_nfkc = value.str.normalize("NFKC") + res_norm_nfkd = value.str.normalize("NFKD") + + assert res_capitalized.dtype == exp_capitalized.dtype + assert res_casefolded.dtype == exp_casefolded.dtype + assert res_lowered.dtype == exp_lowered.dtype + assert res_swapped.dtype == exp_swapped.dtype + assert res_titled.dtype == exp_titled.dtype + assert res_uppered.dtype == exp_uppered.dtype + + assert res_norm_nfc.dtype == exp_norm_nfc.dtype + assert res_norm_nfd.dtype == exp_norm_nfd.dtype + assert res_norm_nfkc.dtype == exp_norm_nfkc.dtype + assert res_norm_nfkd.dtype == exp_norm_nfkd.dtype + + assert_equal(res_capitalized, exp_capitalized) + assert_equal(res_casefolded, exp_casefolded) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) + + assert_equal(res_norm_nfc, exp_norm_nfc) + assert_equal(res_norm_nfd, exp_norm_nfd) + assert_equal(res_norm_nfkc, exp_norm_nfkc) + assert_equal(res_norm_nfkd, exp_norm_nfkd) def test_replace(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) result = values.str.replace("BAD[_]*", "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace("BAD[_]*", "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + pat = xr.DataArray(["BAD[_]*", "AD[_]*"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) - s = xr.DataArray(["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"]).astype( + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( dtype ) - result = s.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + values = xr.DataArray( + ["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"] + ).astype(dtype) expected = xr.DataArray( ["YYY", "B", "C", "YYYaba", "Baca", "", "CYYYBYYY", "dog", "cat"] ).astype(dtype) + result = values.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.replace("A", "YYY", regex=False) + assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("A", "YYY", case=False) + result = values.str.replace("A", "YYY", case=False) expected = xr.DataArray( ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", "", "CYYYBYYY", "dog", "cYYYt"] ).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("^.a|dog", "XX-XX ", case=False) + result = values.str.replace("^.a|dog", "XX-XX ", case=False) expected = xr.DataArray( ["A", "B", "C", "XX-XX ba", "XX-XX ca", "", "XX-XX BA", "XX-XX ", "XX-XX t"] ).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) def test_replace_callable(): values = xr.DataArray(["fooBAD__barBAD"]) + # test with callable repl = lambda m: m.group(0).swapcase() result = values.str.replace("[a-z][A-Z]{2}", repl, n=2) exp = xr.DataArray(["foObaD__baRbaD"]) + assert result.dtype == exp.dtype assert_equal(result, exp) + # test regex named groups values = xr.DataArray(["Foo Bar Baz"]) pat = r"(?P\w+) (?P\w+) (?P\w+)" repl = lambda m: m.group("middle").swapcase() result = values.str.replace(pat, repl) exp = xr.DataArray(["bAR"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + # test broadcast + values = xr.DataArray(["Foo Bar Baz"], dims=["x"]) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = xr.DataArray( + [ + lambda m: m.group("first").swapcase(), + lambda m: m.group("middle").swapcase(), + lambda m: m.group("last").swapcase(), + ], + dims=["Y"], + ) + result = values.str.replace(pat, repl) + exp = xr.DataArray([["fOO", "bAR", "bAZ"]], dims=["x", "Y"]) + assert result.dtype == exp.dtype assert_equal(result, exp) @@ -161,19 +427,54 @@ def test_replace_unicode(): expected = xr.DataArray([b"abcd, \xc3\xa0".decode("utf-8")]) pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) result = values.str.replace(pat, ", ") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast version + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")], dims=["X"]) + expected = xr.DataArray( + [[b"abcd, \xc3\xa0".decode("utf-8"), b"BAcd,\xc3\xa0".decode("utf-8")]], + dims=["X", "Y"], + ) + pat = xr.DataArray( + [re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE), r"ab"], dims=["Y"] + ) + repl = xr.DataArray([", ", "BA"], dims=["Y"]) + result = values.str.replace(pat, repl) + assert result.dtype == expected.dtype assert_equal(result, expected) def test_replace_compiled_regex(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) + # test with compiled regex pat = re.compile(dtype("BAD[_]*")) result = values.str.replace(pat, "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace(pat, "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast + pat = xr.DataArray( + [re.compile(dtype("BAD[_]*")), re.compile(dtype("AD[_]*"))], dims=["y"] + ) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( + dtype + ) + assert result.dtype == expected.dtype assert_equal(result, expected) # case and flags provided to str.replace will have no effect @@ -181,13 +482,19 @@ def test_replace_compiled_regex(dtype): values = xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) pat = re.compile(dtype("BAD[_]*")) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises( + ValueError, match="Flags cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", flags=re.IGNORECASE) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=False) - with pytest.raises(ValueError, match="case and flags cannot be"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=True) # test with callable @@ -196,18 +503,37 @@ def test_replace_compiled_regex(dtype): pat = re.compile(dtype("[a-z][A-Z]{2}")) result = values.str.replace(pat, repl, n=2) expected = xr.DataArray(["foObaD__baRbaD"]).astype(dtype) + assert result.dtype == expected.dtype assert_equal(result, expected) def test_replace_literal(dtype): # GH16808 literal replace (regex=False vs regex=True) - values = xr.DataArray(["f.o", "foo"]).astype(dtype) - expected = xr.DataArray(["bao", "bao"]).astype(dtype) + values = xr.DataArray(["f.o", "foo"], dims=["X"]).astype(dtype) + expected = xr.DataArray(["bao", "bao"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba") + assert result.dtype == expected.dtype assert_equal(result, expected) - expected = xr.DataArray(["bao", "foo"]).astype(dtype) + expected = xr.DataArray(["bao", "foo"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Broadcast + pat = xr.DataArray(["f.", ".o"], dims=["yy"]).astype(dtype) + expected = xr.DataArray([["bao", "fba"], ["bao", "bao"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = xr.DataArray([["bao", "fba"], ["foo", "foo"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba", regex=False) + assert result.dtype == expected.dtype assert_equal(result, expected) # Cannot do a literal replace if given a callable repl or compiled @@ -224,479 +550,3062 @@ def test_replace_literal(dtype): values.str.replace(compiled_pat, "", regex=False) -def test_repeat(dtype): - values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) - result = values.str.repeat(3) - expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) - assert_equal(result, expected) +def test_extract_extractall_findall_empty_raises(dtype): + pat_str = dtype(r".*") + pat_re = re.compile(pat_str) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) -def test_match(dtype): - # New match behavior introduced in 0.13 - values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*(BAD[_]+).*(BAD)") - expected = xr.DataArray([True, False]) - assert_equal(result, expected) + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extract(pat=pat_str, dim="ZZ") - values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*BAD[_]+.*BAD") - expected = xr.DataArray([True, False]) - assert_equal(result, expected) + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extract(pat=pat_re, dim="ZZ") + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") -def test_empty_str_methods(): - empty = xr.DataArray(np.empty(shape=(0,), dtype="U")) - empty_str = empty - empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) - empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) - empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype="S")) + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert_equal(empty_str, empty.str.title()) - assert_equal(empty_int, empty.str.count("a")) - assert_equal(empty_bool, empty.str.contains("a")) - assert_equal(empty_bool, empty.str.startswith("a")) - assert_equal(empty_bool, empty.str.endswith("a")) - assert_equal(empty_str, empty.str.lower()) - assert_equal(empty_str, empty.str.upper()) - assert_equal(empty_str, empty.str.replace("a", "b")) - assert_equal(empty_str, empty.str.repeat(3)) - assert_equal(empty_bool, empty.str.match("^a")) - assert_equal(empty_int, empty.str.len()) - assert_equal(empty_int, empty.str.find("a")) - assert_equal(empty_int, empty.str.rfind("a")) - assert_equal(empty_str, empty.str.pad(42)) - assert_equal(empty_str, empty.str.center(42)) - assert_equal(empty_str, empty.str.slice(stop=1)) - assert_equal(empty_str, empty.str.slice(step=1)) - assert_equal(empty_str, empty.str.strip()) - assert_equal(empty_str, empty.str.lstrip()) - assert_equal(empty_str, empty.str.rstrip()) - assert_equal(empty_str, empty.str.wrap(42)) - assert_equal(empty_str, empty.str.get(0)) - assert_equal(empty_str, empty_bytes.str.decode("ascii")) - assert_equal(empty_bytes, empty.str.encode("ascii")) - assert_equal(empty_str, empty.str.isalnum()) - assert_equal(empty_str, empty.str.isalpha()) - assert_equal(empty_str, empty.str.isdigit()) - assert_equal(empty_str, empty.str.isspace()) - assert_equal(empty_str, empty.str.islower()) - assert_equal(empty_str, empty.str.isupper()) - assert_equal(empty_str, empty.str.istitle()) - assert_equal(empty_str, empty.str.isnumeric()) - assert_equal(empty_str, empty.str.isdecimal()) - assert_equal(empty_str, empty.str.capitalize()) - assert_equal(empty_str, empty.str.swapcase()) - table = str.maketrans("a", "b") - assert_equal(empty_str, empty.str.translate(table)) + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.findall(pat=pat_str) + with pytest.raises(ValueError, match="No capture groups found in pattern."): + value.str.findall(pat=pat_re) -def test_ismethods(dtype): - values = ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] - str_s = xr.DataArray(values).astype(dtype) - alnum_e = [True, True, True, True, True, False, True, True, False, False] - alpha_e = [True, True, True, False, False, False, True, False, False, False] - digit_e = [False, False, False, True, False, False, False, True, False, False] - space_e = [False, False, False, False, False, False, False, False, False, True] - lower_e = [False, True, False, False, False, False, False, False, False, False] - upper_e = [True, False, False, False, True, False, True, False, False, False] - title_e = [True, False, True, False, True, False, False, False, False, False] - - assert_equal(str_s.str.isalnum(), xr.DataArray(alnum_e)) - assert_equal(str_s.str.isalpha(), xr.DataArray(alpha_e)) - assert_equal(str_s.str.isdigit(), xr.DataArray(digit_e)) - assert_equal(str_s.str.isspace(), xr.DataArray(space_e)) - assert_equal(str_s.str.islower(), xr.DataArray(lower_e)) - assert_equal(str_s.str.isupper(), xr.DataArray(upper_e)) - assert_equal(str_s.str.istitle(), xr.DataArray(title_e)) +def test_extract_multi_None_raises(dtype): + pat_str = r"(\w+)_(\d+)" + pat_re = re.compile(pat_str) -def test_isnumeric(): - # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER - # 0x2605: ★ not number - # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY - # 0xFF13: 3 Em 3 - values = ["A", "3", "¼", "★", "፸", "3", "four"] - s = xr.DataArray(values) - numeric_e = [False, True, True, False, True, True, False] - decimal_e = [False, True, False, False, False, True, False] - assert_equal(s.str.isnumeric(), xr.DataArray(numeric_e)) - assert_equal(s.str.isdecimal(), xr.DataArray(decimal_e)) + value = xr.DataArray([["a_b"]], dims=["X", "Y"]).astype(dtype) + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): + value.str.extract(pat=pat_str, dim=None) -def test_len(dtype): - values = ["foo", "fooo", "fooooo", "fooooooo"] - result = xr.DataArray(values).astype(dtype).str.len() - expected = xr.DataArray([len(x) for x in values]) - assert_equal(result, expected) + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): + value.str.extract(pat=pat_re, dim=None) -def test_find(dtype): - values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) - values = values.astype(dtype) - result = values.str.find("EF") - assert_equal(result, xr.DataArray([4, 3, 1, 0, -1])) - expected = xr.DataArray([v.find(dtype("EF")) for v in values.values]) - assert_equal(result, expected) +def test_extract_extractall_findall_case_re_raises(dtype): + pat_str = r".*" + pat_re = re.compile(pat_str) - result = values.str.rfind("EF") - assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) - expected = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) - assert_equal(result, expected) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - result = values.str.find("EF", 3) - assert_equal(result, xr.DataArray([4, 3, 7, 4, -1])) - expected = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) - assert_equal(result, expected) + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extract(pat=pat_re, case=True, dim="ZZ") - result = values.str.rfind("EF", 3) - assert_equal(result, xr.DataArray([4, 5, 7, 4, -1])) - expected = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) - assert_equal(result, expected) + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extract(pat=pat_re, case=False, dim="ZZ") - result = values.str.find("EF", 3, 6) - assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) - expected = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) - assert_equal(result, expected) + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extractall(pat=pat_re, case=True, group_dim="XX", match_dim="YY") - result = values.str.rfind("EF", 3, 6) - assert_equal(result, xr.DataArray([4, 3, -1, 4, -1])) - xp = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) - assert_equal(result, xp) + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.extractall(pat=pat_re, case=False, group_dim="XX", match_dim="YY") + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.findall(pat=pat_re, case=True) -def test_index(dtype): - s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): + value.str.findall(pat=pat_re, case=False) - result = s.str.index("EF") - assert_equal(result, xr.DataArray([4, 3, 1, 0])) - result = s.str.rindex("EF") - assert_equal(result, xr.DataArray([4, 5, 7, 4])) +def test_extract_extractall_name_collision_raises(dtype): + pat_str = r"(\w+)" + pat_re = re.compile(pat_str) - result = s.str.index("EF", 3) - assert_equal(result, xr.DataArray([4, 3, 7, 4])) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - result = s.str.rindex("EF", 3) - assert_equal(result, xr.DataArray([4, 5, 7, 4])) + with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."): + value.str.extract(pat=pat_str, dim="X") - result = s.str.index("E", 4, 8) - assert_equal(result, xr.DataArray([4, 5, 7, 4])) + with pytest.raises(KeyError, match="Dimension 'X' already present in DataArray."): + value.str.extract(pat=pat_re, dim="X") - result = s.str.rindex("E", 0, 5) - assert_equal(result, xr.DataArray([4, 3, 1, 4])) + with pytest.raises( + KeyError, match="Group dimension 'X' already present in DataArray." + ): + value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ") - with pytest.raises(ValueError): - result = s.str.index("DE") + with pytest.raises( + KeyError, match="Group dimension 'X' already present in DataArray." + ): + value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY") + with pytest.raises( + KeyError, match="Match dimension 'Y' already present in DataArray." + ): + value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y") -def test_pad(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + with pytest.raises( + KeyError, match="Match dimension 'Y' already present in DataArray." + ): + value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y") - result = values.str.pad(5, side="left") - expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) - assert_equal(result, expected) + with pytest.raises( + KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'." + ): + value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ") - result = values.str.pad(5, side="right") - expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) - assert_equal(result, expected) + with pytest.raises( + KeyError, match="Group dimension 'ZZ' is the same as match dimension 'ZZ'." + ): + value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") - result = values.str.pad(5, side="both") - expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) - assert_equal(result, expected) +def test_extract_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) -def test_pad_fillchar(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - result = values.str.pad(5, side="left", fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXXb", "XXXXc", "eeeee"]).astype(dtype) - assert_equal(result, expected) + targ_none = xr.DataArray( + [["a", "bab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["bab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) - result = values.str.pad(5, side="right", fillchar="X") - expected = xr.DataArray(["aXXXX", "bXXXX", "cXXXX", "eeeee"]).astype(dtype) - assert_equal(result, expected) + res_str_none = value.str.extract(pat=pat_str, dim=None) + res_str_dim = value.str.extract(pat=pat_str, dim="XX") + res_str_none_case = value.str.extract(pat=pat_str, dim=None, case=True) + res_str_dim_case = value.str.extract(pat=pat_str, dim="XX", case=True) + res_re_none = value.str.extract(pat=pat_re, dim=None) + res_re_dim = value.str.extract(pat=pat_re, dim="XX") + + assert res_str_none.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_str_none_case.dtype == targ_none.dtype + assert res_str_dim_case.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype + + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_str_none_case, targ_none) + assert_equal(res_str_dim_case, targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) + + +def test_extract_single_nocase(dtype): + pat_str = r"(\w+)?_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "_Xy_1", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - result = values.str.pad(5, side="both", fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbXX", "XXcXX", "eeeee"]).astype(dtype) - assert_equal(result, expected) + targ_none = xr.DataArray( + [["a", "ab", "abc"], ["abcd", "", "abcdef"]], dims=["X", "Y"] + ).astype(dtype) + targ_dim = xr.DataArray( + [[["a"], ["ab"], ["abc"]], [["abcd"], [""], ["abcdef"]]], dims=["X", "Y", "XX"] + ).astype(dtype) - msg = "fillchar must be a character, not str" - with pytest.raises(TypeError, match=msg): - result = values.str.pad(5, fillchar="XY") + res_str_none = value.str.extract(pat=pat_str, dim=None, case=False) + res_str_dim = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re_none = value.str.extract(pat=pat_re, dim=None) + res_re_dim = value.str.extract(pat=pat_re, dim="XX") + assert res_re_dim.dtype == targ_none.dtype + assert res_str_dim.dtype == targ_dim.dtype + assert res_re_none.dtype == targ_none.dtype + assert res_re_dim.dtype == targ_dim.dtype -def test_translate(): - values = xr.DataArray(["abcdefg", "abcc", "cdddfg", "cdefggg"]) - table = str.maketrans("abc", "cde") - result = values.str.translate(table) - expected = xr.DataArray(["cdedefg", "cdee", "edddfg", "edefggg"]) - assert_equal(result, expected) + assert_equal(res_str_none, targ_none) + assert_equal(res_str_dim, targ_dim) + assert_equal(res_re_none, targ_none) + assert_equal(res_re_dim, targ_dim) -def test_center_ljust_rjust(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) +def test_extract_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) - result = values.str.center(5) - expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) - assert_equal(result, expected) + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - result = values.str.ljust(5) - expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) - assert_equal(result, expected) + expected = xr.DataArray( + [ + [["a", "0"], ["bab", "110"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) - result = values.str.rjust(5) - expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) - assert_equal(result, expected) + res_str = value.str.extract(pat=pat_str, dim="XX") + res_re = value.str.extract(pat=pat_re, dim="XX") + res_str_case = value.str.extract(pat=pat_str, dim="XX", case=True) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype -def test_center_ljust_rjust_fillchar(dtype): - values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) - result = values.str.center(5, fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]) - assert_equal(result, expected.astype(dtype)) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) - result = values.str.ljust(5, fillchar="X") - expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]) - assert_equal(result, expected.astype(dtype)) - result = values.str.rjust(5, fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]) - assert_equal(result, expected.astype(dtype)) +def test_extract_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) - # If fillchar is not a charatter, normal str raises TypeError - # 'aaa'.ljust(5, 'XY') - # TypeError: must be char, not str - template = "fillchar must be a character, not {dtype}" + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.center(5, fillchar="XY") + expected = xr.DataArray( + [ + [["a", "0"], ["ab", "10"], ["abc", "01"]], + [["abcd", ""], ["", ""], ["abcdef", "101"]], + ], + dims=["X", "Y", "XX"], + ).astype(dtype) - with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.ljust(5, fillchar="XY") + res_str = value.str.extract(pat=pat_str, dim="XX", case=False) + res_re = value.str.extract(pat=pat_re, dim="XX") - with pytest.raises(TypeError, match=template.format(dtype="str")): - values.str.rjust(5, fillchar="XY") + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert_equal(res_str, expected) + assert_equal(res_re, expected) -def test_zfill(dtype): - values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) - result = values.str.zfill(5) - expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]) - assert_equal(result, expected.astype(dtype)) +def test_extract_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) - result = values.str.zfill(3) - expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]) - assert_equal(result, expected.astype(dtype)) + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + expected = [ + [["a", "0"], ["", ""]], + [["", ""], ["ab", "10"]], + [["abc", "01"], ["", ""]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "Zz"]).astype(dtype) -def test_slice(dtype): - arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + res_str = value.str.extract(pat=pat_str, dim="Zz") + res_re = value.str.extract(pat=pat_re, dim="Zz") - result = arr.str.slice(2, 5) - exp = xr.DataArray(["foo", "bar", "baz"]).astype(dtype) - assert_equal(result, exp) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - for start, stop, step in [(0, 3, -1), (None, None, -1), (3, 10, 2), (3, 0, -1)]: - try: - result = arr.str[start:stop:step] - expected = xr.DataArray([s[start:stop:step] for s in arr.values]) - assert_equal(result, expected.astype(dtype)) - except IndexError: - print(f"failed on {start}:{stop}:{step}") - raise + assert_equal(res_str, expected) + assert_equal(res_re, expected) -def test_slice_replace(dtype): - da = lambda x: xr.DataArray(x).astype(dtype) - values = da(["short", "a bit longer", "evenlongerthanthat", ""]) +def test_extractall_single_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) - expected = da(["shrt", "a it longer", "evnlongerthanthat", ""]) - result = values.str.slice_replace(2, 3) - assert_equal(result, expected) + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) - expected = da(["shzrt", "a zit longer", "evznlongerthanthat", "z"]) - result = values.str.slice_replace(2, 3, "z") - assert_equal(result, expected) + expected = xr.DataArray( + [[[["a"]], [[""]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) - expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) - result = values.str.slice_replace(2, 2, "z") - assert_equal(result, expected) + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) - expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) - result = values.str.slice_replace(2, 1, "z") - assert_equal(result, expected) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - expected = da(["shorz", "a bit longez", "evenlongerthanthaz", "z"]) - result = values.str.slice_replace(-1, None, "z") - assert_equal(result, expected) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) - expected = da(["zrt", "zer", "zat", "z"]) - result = values.str.slice_replace(None, -2, "z") - assert_equal(result, expected) - expected = da(["shortz", "a bit znger", "evenlozerthanthat", "z"]) - result = values.str.slice_replace(6, 8, "z") - assert_equal(result, expected) +def test_extractall_single_single_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) - expected = da(["zrt", "a zit longer", "evenlongzerthanthat", "z"]) - result = values.str.slice_replace(-10, 3, "z") - assert_equal(result, expected) + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + expected = xr.DataArray( + [[[["a"]], [["ab"]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) -def test_strip_lstrip_rstrip(dtype): - values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - result = values.str.strip() - expected = xr.DataArray(["aa", "bb", "cc"]).astype(dtype) - assert_equal(result, expected) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - result = values.str.lstrip() - expected = xr.DataArray(["aa ", "bb \n", "cc "]).astype(dtype) - assert_equal(result, expected) + assert_equal(res_str, expected) + assert_equal(res_re, expected) - result = values.str.rstrip() - expected = xr.DataArray([" aa", " bb", "cc"]).astype(dtype) - assert_equal(result, expected) +def test_extractall_single_multi_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) -def test_strip_lstrip_rstrip_args(dtype): - values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - rs = values.str.strip("x") - xp = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) - assert_equal(rs, xp) + expected = xr.DataArray( + [ + [[["a"], [""], [""]], [["bab"], ["baab"], [""]], [["abc"], ["cbc"], [""]]], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) - rs = values.str.lstrip("x") - xp = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) - assert_equal(rs, xp) + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) - rs = values.str.rstrip("x") - xp = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) - assert_equal(rs, xp) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) -def test_wrap(): - # test values are: two words less than width, two words equal to width, - # two words greater than width, one word less than width, one word - # equal to width, one word greater than width, multiple tokens with - # trailing whitespace equal to width - values = xr.DataArray( + +def test_extractall_single_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( [ - "hello world", - "hello world!", - "hello world!!", - "abcdefabcde", - "abcdefabcdef", - "abcdefabcdefa", - "ab ab ab ab ", - "ab ab ab ab a", - "\t", - ] - ) + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) - # expected values expected = xr.DataArray( [ - "hello world", - "hello world!", - "hello\nworld!!", - "abcdefabcde", - "abcdefabcdef", - "abcdefabcdef\na", - "ab ab ab ab", - "ab ab ab ab\na", - "", - ] + [ + [["a"], [""], [""]], + [["ab"], ["bab"], ["baab"]], + [["abc"], ["cbc"], [""]], + ], + [ + [["abcd"], ["dcd"], ["dccd"]], + [[""], [""], [""]], + [["abcdef"], ["fef"], [""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - result = values.str.wrap(12, break_long_words=True) - assert_equal(result, expected) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - # test with pre and post whitespace (non-unicode), NaN, and non-ascii - # Unicode - values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) - expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) - result = values.str.wrap(6) - assert_equal(result, expected) + assert_equal(res_str, expected) + assert_equal(res_re, expected) -def test_wrap_kwargs_passed(): - # GH4334 +def test_extractall_multi_single_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) - values = xr.DataArray(" hello world ") + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) - result = values.str.wrap(7) - expected = xr.DataArray(" hello\nworld") - assert_equal(result, expected) + expected = xr.DataArray( + [ + [[["a", "0"]], [["", ""]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) - result = values.str.wrap(7, drop_whitespace=False) - expected = xr.DataArray(" hello\n world\n ") - assert_equal(result, expected) + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True + ) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype -def test_get(dtype): - values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"]).astype(dtype) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) - result = values.str[2] - expected = xr.DataArray(["b", "d", "g"]).astype(dtype) - assert_equal(result, expected) - # bounds testing - values = xr.DataArray(["1_2_3_4_5", "6_7_8_9_10", "11_12"]).astype(dtype) +def test_extractall_multi_single_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) - # positive index - result = values.str[5] - expected = xr.DataArray(["_", "_", ""]).astype(dtype) - assert_equal(result, expected) + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) - # negative index - result = values.str[-6] - expected = xr.DataArray(["_", "8", ""]).astype(dtype) - assert_equal(result, expected) + expected = xr.DataArray( + [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False + ) + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") -def test_get_default(dtype): - # GH4334 - values = xr.DataArray(["a_b", "c", ""]).astype(dtype) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - result = values.str.get(2, "default") - expected = xr.DataArray(["b", "default", "default"]).astype(dtype) - assert_equal(result, expected) + assert_equal(res_str, expected) + assert_equal(res_re, expected) -def test_encode_decode(): - data = xr.DataArray(["a", "b", "a\xe4"]) - encoded = data.str.encode("utf-8") - decoded = encoded.str.decode("utf-8") - assert_equal(data, decoded) +def test_extractall_multi_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) -def test_encode_decode_errors(): - encodeBase = xr.DataArray(["a", "b", "a\x9d"]) + expected = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["bab", "110"], ["baab", "1100"], ["", ""]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) - msg = ( - r"'charmap' codec can't encode character '\\x9d' in position 1:" - " character maps to " + res_str = value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") + res_str_case = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - with pytest.raises(UnicodeEncodeError, match=msg): - encodeBase.str.encode("cp1252") - f = lambda x: x.encode("cp1252", "ignore") - result = encodeBase.str.encode("cp1252", "ignore") - expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()]) - assert_equal(result, expected) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - decodeBase = xr.DataArray([b"a", b"b", b"a\x9d"]) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) - msg = ( - "'charmap' codec can't decode byte 0x9d in position 1:" - " character maps to " + +def test_extractall_multi_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = xr.DataArray( + [ + [ + [["a", "0"], ["", ""], ["", ""]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"], ["", ""]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [["", ""], ["", ""], ["", ""]], + [["abcdef", "101"], ["fef", "5543210"], ["", ""]], + ], + ], + dims=["X", "Y", "XX", "YY"], + ).astype(dtype) + + res_str = value.str.extractall( + pat=pat_str, group_dim="XX", match_dim="YY", case=False ) - with pytest.raises(UnicodeDecodeError, match=msg): - decodeBase.str.decode("cp1252") + res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - f = lambda x: x.decode("cp1252", "ignore") - result = decodeBase.str.decode("cp1252", "ignore") - expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()]) - assert_equal(result, expected) + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [ + [[["a", "0"]], [["", ""]]], + [[["", ""]], [["ab", "10"]]], + [[["abc", "01"]], [["", ""]]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "ZX", "ZY"]).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="ZX", match_dim="ZY") + res_re = value.str.extractall(pat=pat_re, group_dim="ZX", match_dim="ZY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_single_single_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_single_single_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_single_multi_case(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [["a"], ["bab", "baab"], ["abc", "cbc"]], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_single_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_\d*" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [ + ["a"], + ["ab", "bab", "baab"], + ["abc", "cbc"], + ], + [ + ["abcd", "dcd", "dccd"], + [], + ["abcdef", "fef"], + ], + ] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_multi_single_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [[["a", "0"]], [], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_multi_single_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], + [[["abcd", ""]], [], [["abcdef", "101"]]], + ] + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_multi_multi_case(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str)) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [ + [["a", "0"]], + [["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + res_str_case = value.str.findall(pat=pat_str, case=True) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) + + +def test_findall_multi_multi_nocase(dtype): + pat_str = r"(\w+)_Xy_(\d*)" + pat_re = re.compile(dtype(pat_str), flags=re.I) + + value = xr.DataArray( + [ + ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], + [ + "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", + "", + "abcdef_Xy_101-fef_Xy_5543210", + ], + ], + dims=["X", "Y"], + ).astype(dtype) + + expected = [ + [ + [["a", "0"]], + [["ab", "10"], ["bab", "110"], ["baab", "1100"]], + [["abc", "01"], ["cbc", "2210"]], + ], + [ + [["abcd", ""], ["dcd", "33210"], ["dccd", "332210"]], + [], + [["abcdef", "101"], ["fef", "5543210"]], + ], + ] + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str, case=False) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_findall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_\d*", r"\w+_Xy_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [[["a"], ["0"]], [[], []], [["abc"], ["01"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_repeat(dtype): + values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) + + result = values.str.repeat(3) + result_mul = values.str * 3 + + expected = xr.DataArray(["aaa", "bbb", "ccc", "ddd"]).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + +def test_repeat_broadcast(dtype): + values = xr.DataArray(["a", "b", "c", "d"], dims=["X"]).astype(dtype) + reps = xr.DataArray([3, 4], dims=["Y"]) + + result = values.str.repeat(reps) + result_mul = values.str * reps + + expected = xr.DataArray( + [["aaa", "aaaa"], ["bbb", "bbbb"], ["ccc", "cccc"], ["ddd", "dddd"]], + dims=["X", "Y"], + ).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + +def test_match(dtype): + values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) + + # New match behavior introduced in 0.13 + pat = values.dtype.type(".*(BAD[_]+).*(BAD)") + result = values.str.match(pat) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-sensitive + pat = values.dtype.type(".*BAD[_]+.*BAD") + result = values.str.match(pat) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-insensitive + pat = values.dtype.type(".*bAd[_]+.*bad") + result = values.str.match(pat, case=False) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_empty_str_methods(): + empty = xr.DataArray(np.empty(shape=(0,), dtype="U")) + empty_str = empty + empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) + empty_bool = xr.DataArray(np.empty(shape=(0,), dtype=bool)) + empty_bytes = xr.DataArray(np.empty(shape=(0,), dtype="S")) + + # TODO: Determine why U and S dtype sizes don't match and figure + # out a reliable way to predict what they should be + + assert empty_bool.dtype == empty.str.contains("a").dtype + assert empty_bool.dtype == empty.str.endswith("a").dtype + assert empty_bool.dtype == empty.str.match("^a").dtype + assert empty_bool.dtype == empty.str.startswith("a").dtype + assert empty_bool.dtype == empty.str.isalnum().dtype + assert empty_bool.dtype == empty.str.isalpha().dtype + assert empty_bool.dtype == empty.str.isdecimal().dtype + assert empty_bool.dtype == empty.str.isdigit().dtype + assert empty_bool.dtype == empty.str.islower().dtype + assert empty_bool.dtype == empty.str.isnumeric().dtype + assert empty_bool.dtype == empty.str.isspace().dtype + assert empty_bool.dtype == empty.str.istitle().dtype + assert empty_bool.dtype == empty.str.isupper().dtype + assert empty_bytes.dtype.kind == empty.str.encode("ascii").dtype.kind + assert empty_int.dtype.kind == empty.str.count("a").dtype.kind + assert empty_int.dtype.kind == empty.str.find("a").dtype.kind + assert empty_int.dtype.kind == empty.str.len().dtype.kind + assert empty_int.dtype.kind == empty.str.rfind("a").dtype.kind + assert empty_str.dtype.kind == empty.str.capitalize().dtype.kind + assert empty_str.dtype.kind == empty.str.center(42).dtype.kind + assert empty_str.dtype.kind == empty.str.get(0).dtype.kind + assert empty_str.dtype.kind == empty.str.lower().dtype.kind + assert empty_str.dtype.kind == empty.str.lstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.pad(42).dtype.kind + assert empty_str.dtype.kind == empty.str.repeat(3).dtype.kind + assert empty_str.dtype.kind == empty.str.rstrip().dtype.kind + assert empty_str.dtype.kind == empty.str.slice(step=1).dtype.kind + assert empty_str.dtype.kind == empty.str.slice(stop=1).dtype.kind + assert empty_str.dtype.kind == empty.str.strip().dtype.kind + assert empty_str.dtype.kind == empty.str.swapcase().dtype.kind + assert empty_str.dtype.kind == empty.str.title().dtype.kind + assert empty_str.dtype.kind == empty.str.upper().dtype.kind + assert empty_str.dtype.kind == empty.str.wrap(42).dtype.kind + assert empty_str.dtype.kind == empty_bytes.str.decode("ascii").dtype.kind + + assert_equal(empty_bool, empty.str.contains("a")) + assert_equal(empty_bool, empty.str.endswith("a")) + assert_equal(empty_bool, empty.str.match("^a")) + assert_equal(empty_bool, empty.str.startswith("a")) + assert_equal(empty_bool, empty.str.isalnum()) + assert_equal(empty_bool, empty.str.isalpha()) + assert_equal(empty_bool, empty.str.isdecimal()) + assert_equal(empty_bool, empty.str.isdigit()) + assert_equal(empty_bool, empty.str.islower()) + assert_equal(empty_bool, empty.str.isnumeric()) + assert_equal(empty_bool, empty.str.isspace()) + assert_equal(empty_bool, empty.str.istitle()) + assert_equal(empty_bool, empty.str.isupper()) + assert_equal(empty_bytes, empty.str.encode("ascii")) + assert_equal(empty_int, empty.str.count("a")) + assert_equal(empty_int, empty.str.find("a")) + assert_equal(empty_int, empty.str.len()) + assert_equal(empty_int, empty.str.rfind("a")) + assert_equal(empty_str, empty.str.capitalize()) + assert_equal(empty_str, empty.str.center(42)) + assert_equal(empty_str, empty.str.get(0)) + assert_equal(empty_str, empty.str.lower()) + assert_equal(empty_str, empty.str.lstrip()) + assert_equal(empty_str, empty.str.pad(42)) + assert_equal(empty_str, empty.str.repeat(3)) + assert_equal(empty_str, empty.str.replace("a", "b")) + assert_equal(empty_str, empty.str.rstrip()) + assert_equal(empty_str, empty.str.slice(step=1)) + assert_equal(empty_str, empty.str.slice(stop=1)) + assert_equal(empty_str, empty.str.strip()) + assert_equal(empty_str, empty.str.swapcase()) + assert_equal(empty_str, empty.str.title()) + assert_equal(empty_str, empty.str.upper()) + assert_equal(empty_str, empty.str.wrap(42)) + assert_equal(empty_str, empty_bytes.str.decode("ascii")) + + table = str.maketrans("a", "b") + assert empty_str.dtype.kind == empty.str.translate(table).dtype.kind + assert_equal(empty_str, empty.str.translate(table)) + + +def test_ismethods(dtype): + values = ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] + + exp_alnum = [True, True, True, True, True, False, True, True, False, False] + exp_alpha = [True, True, True, False, False, False, True, False, False, False] + exp_digit = [False, False, False, True, False, False, False, True, False, False] + exp_space = [False, False, False, False, False, False, False, False, False, True] + exp_lower = [False, True, False, False, False, False, False, False, False, False] + exp_upper = [True, False, False, False, True, False, True, False, False, False] + exp_title = [True, False, True, False, True, False, False, False, False, False] + + values = xr.DataArray(values).astype(dtype) + + exp_alnum = xr.DataArray(exp_alnum) + exp_alpha = xr.DataArray(exp_alpha) + exp_digit = xr.DataArray(exp_digit) + exp_space = xr.DataArray(exp_space) + exp_lower = xr.DataArray(exp_lower) + exp_upper = xr.DataArray(exp_upper) + exp_title = xr.DataArray(exp_title) + + res_alnum = values.str.isalnum() + res_alpha = values.str.isalpha() + res_digit = values.str.isdigit() + res_lower = values.str.islower() + res_space = values.str.isspace() + res_title = values.str.istitle() + res_upper = values.str.isupper() + + assert res_alnum.dtype == exp_alnum.dtype + assert res_alpha.dtype == exp_alpha.dtype + assert res_digit.dtype == exp_digit.dtype + assert res_lower.dtype == exp_lower.dtype + assert res_space.dtype == exp_space.dtype + assert res_title.dtype == exp_title.dtype + assert res_upper.dtype == exp_upper.dtype + + assert_equal(res_alnum, exp_alnum) + assert_equal(res_alpha, exp_alpha) + assert_equal(res_digit, exp_digit) + assert_equal(res_lower, exp_lower) + assert_equal(res_space, exp_space) + assert_equal(res_title, exp_title) + assert_equal(res_upper, exp_upper) + + +def test_isnumeric(): + # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER + # 0x2605: ★ not number + # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY + # 0xFF13: 3 Em 3 + values = ["A", "3", "¼", "★", "፸", "3", "four"] + exp_numeric = [False, True, True, False, True, True, False] + exp_decimal = [False, True, False, False, False, True, False] + + values = xr.DataArray(values) + exp_numeric = xr.DataArray(exp_numeric) + exp_decimal = xr.DataArray(exp_decimal) + + res_numeric = values.str.isnumeric() + res_decimal = values.str.isdecimal() + + assert res_numeric.dtype == exp_numeric.dtype + assert res_decimal.dtype == exp_decimal.dtype + + assert_equal(res_numeric, exp_numeric) + assert_equal(res_decimal, exp_decimal) + + +def test_len(dtype): + values = ["foo", "fooo", "fooooo", "fooooooo"] + result = xr.DataArray(values).astype(dtype).str.len() + expected = xr.DataArray([len(x) for x in values]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_find(dtype): + values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) + values = values.astype(dtype) + + result_0 = values.str.find("EF") + result_1 = values.str.find("EF", side="left") + expected_0 = xr.DataArray([4, 3, 1, 0, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF") + result_1 = values.str.find("EF", side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3) + result_1 = values.str.find("EF", 3, side="left") + expected_0 = xr.DataArray([4, 3, 7, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3) + result_1 = values.str.find("EF", 3, side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="left") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="right") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + +def test_find_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"], dims=["X"] + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC", "XX"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 7], dims=["Z"]) + end = xr.DataArray([6, 9], dims=["Z"]) + + result_0 = values.str.find(sub, start, end) + result_1 = values.str.find(sub, start, end, side="left") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [0, -1]], + ], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = values.str.rfind(sub, start, end) + result_1 = values.str.find(sub, start, end, side="right") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[4, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [1, -1]], + ], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + +def test_index(dtype): + s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) + + result_0 = s.str.index("EF") + result_1 = s.str.index("EF", side="left") + expected = xr.DataArray([4, 3, 1, 0]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("EF") + result_1 = s.str.index("EF", side="right") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.index("EF", 3) + result_1 = s.str.index("EF", 3, side="left") + expected = xr.DataArray([4, 3, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("EF", 3) + result_1 = s.str.index("EF", 3, side="right") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.index("E", 4, 8) + result_1 = s.str.index("E", 4, 8, side="left") + expected = xr.DataArray([4, 5, 7, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = s.str.rindex("E", 0, 5) + result_1 = s.str.index("E", 0, 5, side="right") + expected = xr.DataArray([4, 3, 1, 4]) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + matchtype = "subsection" if dtype == np.bytes_ else "substring" + with pytest.raises(ValueError, match=f"{matchtype} not found"): + s.str.index("DE") + + +def test_index_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFGEFDBCA", "BCDEFEFEFDBC", "DEFBCGHIEFBC", "EFGHBCEFBCBCBCEF"], + dims=["X"], + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 6], dims=["Z"]) + end = xr.DataArray([6, 12], dims=["Z"]) + + result_0 = values.str.index(sub, start, end) + result_1 = values.str.index(sub, start, end, side="left") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 8]]], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + result_0 = values.str.rindex(sub, start, end) + result_1 = values.str.index(sub, start, end, side="right") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 10]]], + dims=["X", "Y", "Z"], + ) + + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) + + +def test_translate(): + values = xr.DataArray(["abcdefg", "abcc", "cdddfg", "cdefggg"]) + table = str.maketrans("abc", "cde") + result = values.str.translate(table) + expected = xr.DataArray(["cdedefg", "cdee", "edddfg", "edefggg"]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_pad_center_ljust_rjust(dtype): + values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + + result = values.str.center(5) + expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(5) + expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="right") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(5) + expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="left") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_pad_center_ljust_rjust_fillchar(dtype): + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) + + result = values.str.center(5, fillchar="X") + expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(5, fillchar="X") + expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="right", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(5, fillchar="X") + expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="left", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # If fillchar is not a charatter, normal str raises TypeError + # 'aaa'.ljust(5, 'XY') + # TypeError: must be char, not str + template = "fillchar must be a character, not {dtype}" + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.center(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.ljust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.rjust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.pad(5, fillchar="XY") + + +def test_pad_center_ljust_rjust_broadcast(dtype): + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"], dims="X").astype( + dtype + ) + width = xr.DataArray([5, 4], dims="Y") + fillchar = xr.DataArray(["X", "#"], dims="Y").astype(dtype) + + result = values.str.center(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXaXX", "#a##"], + ["XXbbX", "#bb#"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(width, side="both", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["aXXXX", "a###"], + ["bbXXX", "bb##"], + ["ccccX", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="right", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXXXa", "###a"], + ["XXXbb", "##bb"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="left", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill(dtype): + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + + result = values.str.zfill(5) + expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.zfill(3) + expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill_broadcast(dtype): + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + width = np.array([4, 5, 0, 3, 8]) + + result = values.str.zfill(width) + expected = xr.DataArray(["0001", "00022", "aaa", "333", "00045678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_slice(dtype): + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + + result = arr.str.slice(2, 5) + exp = xr.DataArray(["foo", "bar", "baz"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + for start, stop, step in [(0, 3, -1), (None, None, -1), (3, 10, 2), (3, 0, -1)]: + try: + result = arr.str[start:stop:step] + expected = xr.DataArray([s[start:stop:step] for s in arr.values]) + assert_equal(result, expected.astype(dtype)) + except IndexError: + print(f"failed on {start}:{stop}:{step}") + raise + + +def test_slice_broadcast(dtype): + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + start = xr.DataArray([1, 2, 3]) + stop = 5 + + result = arr.str.slice(start=start, stop=stop) + exp = xr.DataArray(["afoo", "bar", "az"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + +def test_slice_replace(dtype): + da = lambda x: xr.DataArray(x).astype(dtype) + values = da(["short", "a bit longer", "evenlongerthanthat", ""]) + + expected = da(["shrt", "a it longer", "evnlongerthanthat", ""]) + result = values.str.slice_replace(2, 3) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzrt", "a zit longer", "evznlongerthanthat", "z"]) + result = values.str.slice_replace(2, 3, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 2, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shzort", "a zbit longer", "evzenlongerthanthat", "z"]) + result = values.str.slice_replace(2, 1, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shorz", "a bit longez", "evenlongerthanthaz", "z"]) + result = values.str.slice_replace(-1, None, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["zrt", "zer", "zat", "z"]) + result = values.str.slice_replace(None, -2, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["shortz", "a bit znger", "evenlozerthanthat", "z"]) + result = values.str.slice_replace(6, 8, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = da(["zrt", "a zit longer", "evenlongzerthanthat", "z"]) + result = values.str.slice_replace(-10, 3, "z") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_slice_replace_broadcast(dtype): + values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( + dtype + ) + start = 2 + stop = np.array([4, 5, None, 7]) + repl = "test" + + expected = xr.DataArray(["shtestt", "a test longer", "evtest", "test"]).astype( + dtype + ) + result = values.str.slice_replace(start, stop, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip(dtype): + values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) + + result = values.str.strip() + expected = xr.DataArray(["aa", "bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip() + expected = xr.DataArray(["aa ", "bb \n", "cc "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip() + expected = xr.DataArray([" aa", " bb", "cc"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_args(dtype): + values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) + + result = values.str.strip("x") + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip("x") + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip("x") + expected = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_broadcast(dtype): + values = xr.DataArray(["xxABCxx", "yy BNSD", "LDFJH zz"]).astype(dtype) + to_strip = xr.DataArray(["x", "y", "z"]).astype(dtype) + + result = values.str.strip(to_strip) + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip(to_strip) + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH zz"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip(to_strip) + expected = xr.DataArray(["xxABC", "yy BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_wrap(): + # test values are: two words less than width, two words equal to width, + # two words greater than width, one word less than width, one word + # equal to width, one word greater than width, multiple tokens with + # trailing whitespace equal to width + values = xr.DataArray( + [ + "hello world", + "hello world!", + "hello world!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdefa", + "ab ab ab ab ", + "ab ab ab ab a", + "\t", + ] + ) + + # expected values + expected = xr.DataArray( + [ + "hello world", + "hello world!", + "hello\nworld!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdef\na", + "ab ab ab ab", + "ab ab ab ab\na", + "", + ] + ) + + result = values.str.wrap(12, break_long_words=True) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # test with pre and post whitespace (non-unicode), NaN, and non-ascii + # Unicode + values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) + expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) + result = values.str.wrap(6) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_wrap_kwargs_passed(): + # GH4334 + + values = xr.DataArray(" hello world ") + + result = values.str.wrap(7) + expected = xr.DataArray(" hello\nworld") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.wrap(7, drop_whitespace=False) + expected = xr.DataArray(" hello\n world\n ") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get(dtype): + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"]).astype(dtype) + + result = values.str[2] + expected = xr.DataArray(["b", "d", "g"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # bounds testing + values = xr.DataArray(["1_2_3_4_5", "6_7_8_9_10", "11_12"]).astype(dtype) + + # positive index + result = values.str[5] + expected = xr.DataArray(["_", "_", ""]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # negative index + result = values.str[-6] + expected = xr.DataArray(["_", "8", ""]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get_default(dtype): + # GH4334 + values = xr.DataArray(["a_b", "c", ""]).astype(dtype) + + result = values.str.get(2, "default") + expected = xr.DataArray(["b", "default", "default"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_get_broadcast(dtype): + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) + inds = xr.DataArray([0, 2], dims=["Y"]) + + result = values.str.get(inds) + expected = xr.DataArray( + [["a", "b"], ["c", "d"], ["f", "g"]], dims=["X", "Y"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_encode_decode(): + data = xr.DataArray(["a", "b", "a\xe4"]) + encoded = data.str.encode("utf-8") + decoded = encoded.str.decode("utf-8") + assert data.dtype == decoded.dtype + assert_equal(data, decoded) + + +def test_encode_decode_errors(): + encodeBase = xr.DataArray(["a", "b", "a\x9d"]) + + msg = ( + r"'charmap' codec can't encode character '\\x9d' in position 1:" + " character maps to " + ) + with pytest.raises(UnicodeEncodeError, match=msg): + encodeBase.str.encode("cp1252") + + f = lambda x: x.encode("cp1252", "ignore") + result = encodeBase.str.encode("cp1252", "ignore") + expected = xr.DataArray([f(x) for x in encodeBase.values.tolist()]) + + assert result.dtype == expected.dtype + assert_equal(result, expected) + + decodeBase = xr.DataArray([b"a", b"b", b"a\x9d"]) + + msg = ( + "'charmap' codec can't decode byte 0x9d in position 1:" + " character maps to " + ) + with pytest.raises(UnicodeDecodeError, match=msg): + decodeBase.str.decode("cp1252") + + f = lambda x: x.decode("cp1252", "ignore") + result = decodeBase.str.decode("cp1252", "ignore") + expected = xr.DataArray([f(x) for x in decodeBase.values.tolist()]) + + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_partition_whitespace(dtype): + values = xr.DataArray( + [ + ["abc def", "spam eggs swallow", "red_blue"], + ["test0 test1 test2 test3", "", "abra ka da bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim = [ + [ + ["abc", " ", "def"], + ["spam", " ", "eggs swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", " ", "test1 test2 test3"], + ["", "", ""], + ["abra", " ", "ka da bra"], + ], + ] + + exp_rpart_dim = [ + [ + ["abc", " ", "def"], + ["spam eggs", " ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0 test1 test2", " ", "test3"], + ["", "", ""], + ["abra ka da", " ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + + res_part_dim = values.str.partition(dim="ZZ") + res_rpart_dim = values.str.rpartition(dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_partition_comma(dtype): + values = xr.DataArray( + [ + ["abc, def", "spam, eggs, swallow", "red_blue"], + ["test0, test1, test2, test3", "", "abra, ka, da, bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_part_dim = [ + [ + ["abc", ", ", "def"], + ["spam", ", ", "eggs, swallow"], + ["red_blue", "", ""], + ], + [ + ["test0", ", ", "test1, test2, test3"], + ["", "", ""], + ["abra", ", ", "ka, da, bra"], + ], + ] + + exp_rpart_dim = [ + [ + ["abc", ", ", "def"], + ["spam, eggs", ", ", "swallow"], + ["", "", "red_blue"], + ], + [ + ["test0, test1, test2", ", ", "test3"], + ["", "", ""], + ["abra, ka, da", ", ", "bra"], + ], + ] + + exp_part_dim = xr.DataArray(exp_part_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rpart_dim = xr.DataArray(exp_rpart_dim, dims=["X", "Y", "ZZ"]).astype(dtype) + + res_part_dim = values.str.partition(sep=", ", dim="ZZ") + res_rpart_dim = values.str.rpartition(sep=", ", dim="ZZ") + + assert res_part_dim.dtype == exp_part_dim.dtype + assert res_rpart_dim.dtype == exp_rpart_dim.dtype + + assert_equal(res_part_dim, exp_part_dim) + assert_equal(res_rpart_dim, exp_rpart_dim) + + +def test_partition_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.partition(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_split_whitespace(dtype): + values = xr.DataArray( + [ + ["abc def", "spam\t\teggs\tswallow", "red_blue"], + ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_split_dim_full = [ + [ + ["abc", "def", "", ""], + ["spam", "eggs", "swallow", ""], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_rsplit_dim_full = [ + [ + ["", "", "abc", "def"], + ["", "spam", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_split_dim_1 = [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue", ""]], + [["test0", "test1\ntest2\n\ntest3"], ["", ""], ["abra", "ka\nda\tbra"]], + ] + + exp_rsplit_dim_1 = [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["", "red_blue"]], + [["test0\ntest1\ntest2", "test3"], ["", ""], ["abra ka\nda", "bra"]], + ] + + exp_split_none_full = [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ] + + exp_rsplit_none_full = [ + [["abc", "def"], ["spam", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [], ["abra", "ka", "da", "bra"]], + ] + + exp_split_none_1 = [ + [["abc", "def"], ["spam", "eggs\tswallow"], ["red_blue"]], + [["test0", "test1\ntest2\n\ntest3"], [], ["abra", "ka\nda\tbra"]], + ] + + exp_rsplit_none_1 = [ + [["abc", "def"], ["spam\t\teggs", "swallow"], ["red_blue"]], + [["test0\ntest1\ntest2", "test3"], [], ["abra ka\nda", "bra"]], + ] + + exp_split_none_full = [ + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full + ] + exp_rsplit_none_full = [ + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full + ] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] + + exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) + exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) + exp_split_none_1 = np.array(exp_split_none_1, dtype=np.object_) + exp_rsplit_none_1 = np.array(exp_rsplit_none_1, dtype=np.object_) + + exp_split_dim_full = xr.DataArray(exp_split_dim_full, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + exp_rsplit_dim_full = xr.DataArray( + exp_rsplit_dim_full, dims=["X", "Y", "ZZ"] + ).astype(dtype) + exp_split_dim_1 = xr.DataArray(exp_split_dim_1, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rsplit_dim_1 = xr.DataArray(exp_rsplit_dim_1, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + exp_split_none_full = xr.DataArray(exp_split_none_full, dims=["X", "Y"]) + exp_rsplit_none_full = xr.DataArray(exp_rsplit_none_full, dims=["X", "Y"]) + exp_split_none_1 = xr.DataArray(exp_split_none_1, dims=["X", "Y"]) + exp_rsplit_none_1 = xr.DataArray(exp_rsplit_none_1, dims=["X", "Y"]) + + res_split_dim_full = values.str.split(dim="ZZ") + res_rsplit_dim_full = values.str.rsplit(dim="ZZ") + res_split_dim_1 = values.str.split(dim="ZZ", maxsplit=1) + res_rsplit_dim_1 = values.str.rsplit(dim="ZZ", maxsplit=1) + res_split_dim_10 = values.str.split(dim="ZZ", maxsplit=10) + res_rsplit_dim_10 = values.str.rsplit(dim="ZZ", maxsplit=10) + + res_split_none_full = values.str.split(dim=None) + res_rsplit_none_full = values.str.rsplit(dim=None) + res_split_none_1 = values.str.split(dim=None, maxsplit=1) + res_rsplit_none_1 = values.str.rsplit(dim=None, maxsplit=1) + res_split_none_10 = values.str.split(dim=None, maxsplit=10) + res_rsplit_none_10 = values.str.rsplit(dim=None, maxsplit=10) + + assert res_split_dim_full.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_full.dtype == exp_rsplit_dim_full.dtype + assert res_split_dim_1.dtype == exp_split_dim_1.dtype + assert res_rsplit_dim_1.dtype == exp_rsplit_dim_1.dtype + assert res_split_dim_10.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_10.dtype == exp_rsplit_dim_full.dtype + + assert res_split_none_full.dtype == exp_split_none_full.dtype + assert res_rsplit_none_full.dtype == exp_rsplit_none_full.dtype + assert res_split_none_1.dtype == exp_split_none_1.dtype + assert res_rsplit_none_1.dtype == exp_rsplit_none_1.dtype + assert res_split_none_10.dtype == exp_split_none_full.dtype + assert res_rsplit_none_10.dtype == exp_rsplit_none_full.dtype + + assert_equal(res_split_dim_full, exp_split_dim_full) + assert_equal(res_rsplit_dim_full, exp_rsplit_dim_full) + assert_equal(res_split_dim_1, exp_split_dim_1) + assert_equal(res_rsplit_dim_1, exp_rsplit_dim_1) + assert_equal(res_split_dim_10, exp_split_dim_full) + assert_equal(res_rsplit_dim_10, exp_rsplit_dim_full) + + assert_equal(res_split_none_full, exp_split_none_full) + assert_equal(res_rsplit_none_full, exp_rsplit_none_full) + assert_equal(res_split_none_1, exp_split_none_1) + assert_equal(res_rsplit_none_1, exp_rsplit_none_1) + assert_equal(res_split_none_10, exp_split_none_full) + assert_equal(res_rsplit_none_10, exp_rsplit_none_full) + + +def test_split_comma(dtype): + values = xr.DataArray( + [ + ["abc,def", "spam,,eggs,swallow", "red_blue"], + ["test0,test1,test2,test3", "", "abra,ka,da,bra"], + ], + dims=["X", "Y"], + ).astype(dtype) + + exp_split_dim_full = [ + [ + ["abc", "def", "", ""], + ["spam", "", "eggs", "swallow"], + ["red_blue", "", "", ""], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_rsplit_dim_full = [ + [ + ["", "", "abc", "def"], + ["spam", "", "eggs", "swallow"], + ["", "", "", "red_blue"], + ], + [ + ["test0", "test1", "test2", "test3"], + ["", "", "", ""], + ["abra", "ka", "da", "bra"], + ], + ] + + exp_split_dim_1 = [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue", ""]], + [["test0", "test1,test2,test3"], ["", ""], ["abra", "ka,da,bra"]], + ] + + exp_rsplit_dim_1 = [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["", "red_blue"]], + [["test0,test1,test2", "test3"], ["", ""], ["abra,ka,da", "bra"]], + ] + + exp_split_none_full = [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [""], ["abra", "ka", "da", "bra"]], + ] + + exp_rsplit_none_full = [ + [["abc", "def"], ["spam", "", "eggs", "swallow"], ["red_blue"]], + [["test0", "test1", "test2", "test3"], [""], ["abra", "ka", "da", "bra"]], + ] + + exp_split_none_1 = [ + [["abc", "def"], ["spam", ",eggs,swallow"], ["red_blue"]], + [["test0", "test1,test2,test3"], [""], ["abra", "ka,da,bra"]], + ] + + exp_rsplit_none_1 = [ + [["abc", "def"], ["spam,,eggs", "swallow"], ["red_blue"]], + [["test0,test1,test2", "test3"], [""], ["abra,ka,da", "bra"]], + ] + + exp_split_none_full = [ + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full + ] + exp_rsplit_none_full = [ + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full + ] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] + + exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) + exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) + exp_split_none_1 = np.array(exp_split_none_1, dtype=np.object_) + exp_rsplit_none_1 = np.array(exp_rsplit_none_1, dtype=np.object_) + + exp_split_dim_full = xr.DataArray(exp_split_dim_full, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + exp_rsplit_dim_full = xr.DataArray( + exp_rsplit_dim_full, dims=["X", "Y", "ZZ"] + ).astype(dtype) + exp_split_dim_1 = xr.DataArray(exp_split_dim_1, dims=["X", "Y", "ZZ"]).astype(dtype) + exp_rsplit_dim_1 = xr.DataArray(exp_rsplit_dim_1, dims=["X", "Y", "ZZ"]).astype( + dtype + ) + + exp_split_none_full = xr.DataArray(exp_split_none_full, dims=["X", "Y"]) + exp_rsplit_none_full = xr.DataArray(exp_rsplit_none_full, dims=["X", "Y"]) + exp_split_none_1 = xr.DataArray(exp_split_none_1, dims=["X", "Y"]) + exp_rsplit_none_1 = xr.DataArray(exp_rsplit_none_1, dims=["X", "Y"]) + + res_split_dim_full = values.str.split(sep=",", dim="ZZ") + res_rsplit_dim_full = values.str.rsplit(sep=",", dim="ZZ") + res_split_dim_1 = values.str.split(sep=",", dim="ZZ", maxsplit=1) + res_rsplit_dim_1 = values.str.rsplit(sep=",", dim="ZZ", maxsplit=1) + res_split_dim_10 = values.str.split(sep=",", dim="ZZ", maxsplit=10) + res_rsplit_dim_10 = values.str.rsplit(sep=",", dim="ZZ", maxsplit=10) + + res_split_none_full = values.str.split(sep=",", dim=None) + res_rsplit_none_full = values.str.rsplit(sep=",", dim=None) + res_split_none_1 = values.str.split(sep=",", dim=None, maxsplit=1) + res_rsplit_none_1 = values.str.rsplit(sep=",", dim=None, maxsplit=1) + res_split_none_10 = values.str.split(sep=",", dim=None, maxsplit=10) + res_rsplit_none_10 = values.str.rsplit(sep=",", dim=None, maxsplit=10) + + assert res_split_dim_full.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_full.dtype == exp_rsplit_dim_full.dtype + assert res_split_dim_1.dtype == exp_split_dim_1.dtype + assert res_rsplit_dim_1.dtype == exp_rsplit_dim_1.dtype + assert res_split_dim_10.dtype == exp_split_dim_full.dtype + assert res_rsplit_dim_10.dtype == exp_rsplit_dim_full.dtype + + assert res_split_none_full.dtype == exp_split_none_full.dtype + assert res_rsplit_none_full.dtype == exp_rsplit_none_full.dtype + assert res_split_none_1.dtype == exp_split_none_1.dtype + assert res_rsplit_none_1.dtype == exp_rsplit_none_1.dtype + assert res_split_none_10.dtype == exp_split_none_full.dtype + assert res_rsplit_none_10.dtype == exp_rsplit_none_full.dtype + + assert_equal(res_split_dim_full, exp_split_dim_full) + assert_equal(res_rsplit_dim_full, exp_rsplit_dim_full) + assert_equal(res_split_dim_1, exp_split_dim_1) + assert_equal(res_rsplit_dim_1, exp_rsplit_dim_1) + assert_equal(res_split_dim_10, exp_split_dim_full) + assert_equal(res_rsplit_dim_10, exp_rsplit_dim_full) + + assert_equal(res_split_none_full, exp_split_none_full) + assert_equal(res_rsplit_none_full, exp_rsplit_none_full) + assert_equal(res_split_none_1, exp_split_none_1) + assert_equal(res_rsplit_none_1, exp_rsplit_none_1) + assert_equal(res_split_none_10, exp_split_none_full) + assert_equal(res_rsplit_none_10, exp_rsplit_none_full) + + +def test_splitters_broadcast(dtype): + values = xr.DataArray( + ["ab cd,de fg", "spam, ,eggs swallow", "red_blue"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ","], + dims=["Y"], + ).astype(dtype) + + expected_left = xr.DataArray( + [ + [["ab", "cd,de fg"], ["ab cd", "de fg"]], + [["spam,", ",eggs swallow"], ["spam", " ,eggs swallow"]], + [["red_blue", ""], ["red_blue", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab cd,de", "fg"], ["ab cd", "de fg"]], + [["spam, ,eggs", "swallow"], ["spam, ", "eggs swallow"]], + [["", "red_blue"], ["", "red_blue"]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.split(dim="ZZ", sep=sep, maxsplit=1) + res_right = values.str.rsplit(dim="ZZ", sep=sep, maxsplit=1) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + expected_left = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.partition(dim="ZZ", sep=sep) + res_right = values.str.partition(dim="ZZ", sep=sep) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + +def test_split_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.split(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_get_dummies(dtype): + values_line = xr.DataArray( + [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], + dims=["X", "Y"], + ).astype(dtype) + values_comma = xr.DataArray( + [["a~ab|abc~~abc", "ab", "a~abc~abcd"], ["abcd~ab~a", "abc~ab|abc", "~a"]], + dims=["X", "Y"], + ).astype(dtype) + + vals_line = np.array(["a", "ab", "abc", "abcd", "ab~abc"]).astype(dtype) + vals_comma = np.array(["a", "ab", "abc", "abcd", "ab|abc"]).astype(dtype) + expected = [ + [ + [True, False, True, False, True], + [False, True, False, False, False], + [True, False, True, True, False], + ], + [ + [True, True, False, True, False], + [False, False, True, False, True], + [True, False, False, False, False], + ], + ] + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + targ_line = expected.copy() + targ_comma = expected.copy() + targ_line.coords["ZZ"] = vals_line + targ_comma.coords["ZZ"] = vals_comma + + res_default = values_line.str.get_dummies(dim="ZZ") + res_line = values_line.str.get_dummies(dim="ZZ", sep="|") + res_comma = values_comma.str.get_dummies(dim="ZZ", sep="~") + + assert res_default.dtype == targ_line.dtype + assert res_line.dtype == targ_line.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_default, targ_line) + assert_equal(res_line, targ_line) + assert_equal(res_comma, targ_comma) + + +def test_get_dummies_broadcast(dtype): + values = xr.DataArray( + ["x~x|x~x", "x", "x|x~x", "x~x"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + ["|", "~"], + dims=["Y"], + ).astype(dtype) + + expected = [ + [[False, False, True], [True, True, False]], + [[True, False, False], [True, False, False]], + [[True, False, True], [True, True, False]], + [[False, False, True], [True, False, False]], + ] + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + expected.coords["ZZ"] = np.array(["x", "x|x", "x~x"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ", sep=sep) + + assert res.dtype == expected.dtype + + assert_equal(res, expected) + + +def test_get_dummies_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_splitters_empty_str(dtype): + values = xr.DataArray( + [["", "", ""], ["", "", ""]], + dims=["X", "Y"], + ).astype(dtype) + + targ_partition_dim = xr.DataArray( + [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + targ_partition_none = [ + [["", "", ""], ["", "", ""], ["", "", ""]], + [["", "", ""], ["", "", ""], ["", "", "", ""]], + ] + targ_partition_none = [ + [[dtype(x) for x in y] for y in z] for z in targ_partition_none + ] + targ_partition_none = np.array(targ_partition_none, dtype=np.object_) + del targ_partition_none[-1, -1][-1] + targ_partition_none = xr.DataArray( + targ_partition_none, + dims=["X", "Y"], + ) + + targ_split_dim = xr.DataArray( + [[[""], [""], [""]], [[""], [""], [""]]], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + targ_split_none = xr.DataArray( + np.array([[[], [], []], [[], [], [""]]], dtype=np.object_), + dims=["X", "Y"], + ) + del targ_split_none.data[-1, -1][-1] + + res_partition_dim = values.str.partition(dim="ZZ") + res_rpartition_dim = values.str.rpartition(dim="ZZ") + res_partition_none = values.str.partition(dim=None) + res_rpartition_none = values.str.rpartition(dim=None) + + res_split_dim = values.str.split(dim="ZZ") + res_rsplit_dim = values.str.rsplit(dim="ZZ") + res_split_none = values.str.split(dim=None) + res_rsplit_none = values.str.rsplit(dim=None) + + res_dummies = values.str.rsplit(dim="ZZ") + + assert res_partition_dim.dtype == targ_partition_dim.dtype + assert res_rpartition_dim.dtype == targ_partition_dim.dtype + assert res_partition_none.dtype == targ_partition_none.dtype + assert res_rpartition_none.dtype == targ_partition_none.dtype + + assert res_split_dim.dtype == targ_split_dim.dtype + assert res_rsplit_dim.dtype == targ_split_dim.dtype + assert res_split_none.dtype == targ_split_none.dtype + assert res_rsplit_none.dtype == targ_split_none.dtype + + assert res_dummies.dtype == targ_split_dim.dtype + + assert_equal(res_partition_dim, targ_partition_dim) + assert_equal(res_rpartition_dim, targ_partition_dim) + assert_equal(res_partition_none, targ_partition_none) + assert_equal(res_rpartition_none, targ_partition_none) + + assert_equal(res_split_dim, targ_split_dim) + assert_equal(res_rsplit_dim, targ_split_dim) + assert_equal(res_split_none, targ_split_none) + assert_equal(res_rsplit_none, targ_split_none) + + assert_equal(res_dummies, targ_split_dim) + + +def test_cat_str(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = "111" + + targ_blank = xr.DataArray( + [["a111", "bb111", "cccc111"], ["ddddd111", "eeee111", "fff111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 111", "bb 111", "cccc 111"], ["ddddd 111", "eeee 111", "fff 111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||111", "bb||111", "cccc||111"], ["ddddd||111", "eeee||111", "fff||111"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 111", "bb, 111", "cccc, 111"], ["ddddd, 111", "eeee, 111", "fff, 111"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_uniform(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd4", "eeee5555", "fff66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 4", "eeee 5555", "fff 66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||4", "eeee||5555", "fff||66"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 4", "eeee, 5555", "fff, 66"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_right(dtype): + values_1 = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "222", "33"], + dims=["Y"], + ) + + targ_blank = xr.DataArray( + [["a11111", "bb222", "cccc33"], ["ddddd11111", "eeee222", "fff33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_space = xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["ddddd 11111", "eeee 222", "fff 33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_bars = xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["ddddd||11111", "eeee||222", "fff||33"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_comma = xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["ddddd, 11111", "eeee, 222", "fff, 33"]], + dims=["X", "Y"], + ).astype(dtype) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_left(dtype): + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + [["11111", "222", "33"], ["4", "5555", "66"]], + dims=["X", "Y"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb222", "cccc33"], ["a4", "bb5555", "cccc66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( + [["a 11111", "bb 222", "cccc 33"], ["a 4", "bb 5555", "cccc 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||222", "cccc||33"], ["a||4", "bb||5555", "cccc||66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 222", "cccc, 33"], ["a, 4", "bb, 5555", "cccc, 66"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_broadcast_both(dtype): + values_1 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + values_2 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + targ_blank = ( + xr.DataArray( + [["a11111", "bb11111", "cccc11111"], ["a4", "bb4", "cccc4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_space = ( + xr.DataArray( + [["a 11111", "bb 11111", "cccc 11111"], ["a 4", "bb 4", "cccc 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_bars = ( + xr.DataArray( + [["a||11111", "bb||11111", "cccc||11111"], ["a||4", "bb||4", "cccc||4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + targ_comma = ( + xr.DataArray( + [["a, 11111", "bb, 11111", "cccc, 11111"], ["a, 4", "bb, 4", "cccc, 4"]], + dims=["X", "Y"], + ) + .astype(dtype) + .T + ) + + res_blank = values_1.str.cat(values_2) + res_add = values_1.str + values_2 + res_space = values_1.str.cat(values_2, sep=" ") + res_bars = values_1.str.cat(values_2, sep="||") + res_comma = values_1.str.cat(values_2, sep=", ") + + assert res_blank.dtype == targ_blank.dtype + assert res_add.dtype == targ_blank.dtype + assert res_space.dtype == targ_space.dtype + assert res_bars.dtype == targ_bars.dtype + assert res_comma.dtype == targ_comma.dtype + + assert_equal(res_blank, targ_blank) + assert_equal(res_add, targ_blank) + assert_equal(res_space, targ_space) + assert_equal(res_bars, targ_bars) + assert_equal(res_comma, targ_comma) + + +def test_cat_multi(): + values_1 = xr.DataArray( + ["11111", "4"], + dims=["X"], + ) + + values_2 = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(np.bytes_) + + values_3 = np.array(3.4) + + values_4 = "" + + values_5 = np.array("", dtype=np.unicode_) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(np.unicode_) + + expected = xr.DataArray( + [ + [ + ["11111 a 3.4 ", "11111, a, 3.4, , "], + ["11111 bb 3.4 ", "11111, bb, 3.4, , "], + ["11111 cccc 3.4 ", "11111, cccc, 3.4, , "], + ], + [ + ["4 a 3.4 ", "4, a, 3.4, , "], + ["4 bb 3.4 ", "4, bb, 3.4, , "], + ["4 cccc 3.4 ", "4, cccc, 3.4, , "], + ], + ], + dims=["X", "Y", "ZZ"], + ).astype(np.unicode_) + + res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_join_scalar(dtype): + values = xr.DataArray("aaa").astype(dtype) + + targ = xr.DataArray("aaa").astype(dtype) + + res_blank = values.str.join() + res_space = values.str.join(sep=" ") + + assert res_blank.dtype == targ.dtype + assert res_space.dtype == targ.dtype + + assert_identical(res_blank, targ) + assert_identical(res_space, targ) + + +def test_join_vector(dtype): + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["Y"], + ).astype(dtype) + + targ_blank = xr.DataArray("abbcccc").astype(dtype) + targ_space = xr.DataArray("a bb cccc").astype(dtype) + + res_blank_none = values.str.join() + res_blank_y = values.str.join(dim="Y") + + res_space_none = values.str.join(sep=" ") + res_space_y = values.str.join(dim="Y", sep=" ") + + assert res_blank_none.dtype == targ_blank.dtype + assert res_blank_y.dtype == targ_blank.dtype + assert res_space_none.dtype == targ_space.dtype + assert res_space_y.dtype == targ_space.dtype + + assert_identical(res_blank_none, targ_blank) + assert_identical(res_blank_y, targ_blank) + assert_identical(res_space_none, targ_space) + assert_identical(res_space_y, targ_space) + + +def test_join_2d(dtype): + values = xr.DataArray( + [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], + dims=["X", "Y"], + ).astype(dtype) + + targ_blank_x = xr.DataArray( + ["addddd", "bbeeee", "ccccfff"], + dims=["Y"], + ).astype(dtype) + targ_space_x = xr.DataArray( + ["a ddddd", "bb eeee", "cccc fff"], + dims=["Y"], + ).astype(dtype) + + targ_blank_y = xr.DataArray( + ["abbcccc", "dddddeeeefff"], + dims=["X"], + ).astype(dtype) + targ_space_y = xr.DataArray( + ["a bb cccc", "ddddd eeee fff"], + dims=["X"], + ).astype(dtype) + + res_blank_x = values.str.join(dim="X") + res_blank_y = values.str.join(dim="Y") + + res_space_x = values.str.join(dim="X", sep=" ") + res_space_y = values.str.join(dim="Y", sep=" ") + + assert res_blank_x.dtype == targ_blank_x.dtype + assert res_blank_y.dtype == targ_blank_y.dtype + assert res_space_x.dtype == targ_space_x.dtype + assert res_space_y.dtype == targ_space_y.dtype + + assert_identical(res_blank_x, targ_blank_x) + assert_identical(res_blank_y, targ_blank_y) + assert_identical(res_space_x, targ_space_x) + assert_identical(res_space_y, targ_space_y) + + with pytest.raises( + ValueError, match="Dimension must be specified for multidimensional arrays." + ): + values.str.join() + + +def test_join_broadcast(dtype): + values = xr.DataArray( + ["a", "bb", "cccc"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ", "], + dims=["ZZ"], + ).astype(dtype) + + expected = xr.DataArray( + ["a bb cccc", "a, bb, cccc"], + dims=["ZZ"], + ).astype(dtype) + + res = values.str.join(sep=sep) + + assert res.dtype == expected.dtype + assert_identical(res, expected) + + +def test_format_scalar(): + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + expected = xr.DataArray( + ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_format_broadcast(): + values = xr.DataArray( + ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + X = "'test'" + Y = "X" + ZZ = None + W = "NO!" + + expected = xr.DataArray( + [ + ["1.X.None", "1.X.None"], + ["1,1.2,'test','test'", "1,1.2,'test','test'"], + ["'test'-X-None", "'test'-X-None"], + ], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_scalar(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + + expected = xr.DataArray( + ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_dict(): + values = xr.DataArray( + ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], + dims=["X"], + ).astype(np.unicode_) + + a = 1 + b = 1.2 + c = "2.3" + + expected = xr.DataArray( + ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % {"a": a, "b": b, "c": c} + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_single(): + values = xr.DataArray( + ["%s_1", "%s_2", "%s_3"], + dims=["X"], + ).astype(np.unicode_) + + pos = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % pos + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_multi(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [ + ["1.1.2.2.3", "1.1.2.3.44444"], + ["1,1.2,2.3", "1,1.2,3.44444"], + ["1-1.2-2.3", "1-1.2-3.44444"], + ], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3750c0715ae..3bbc2c93b31 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1,8 +1,10 @@ import contextlib +import gzip import itertools import math import os.path import pickle +import re import shutil import sys import tempfile @@ -30,9 +32,14 @@ save_mfdataset, ) from xarray.backends.common import robust_getitem +from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint from xarray.backends.netcdf3 import _nc3_dtype_coercions -from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding +from xarray.backends.netCDF4_ import ( + NetCDF4BackendEntrypoint, + _extract_nc4_variable_encoding, +) from xarray.backends.pydap_ import PydapDataStore +from xarray.backends.scipy_ import ScipyBackendEntrypoint from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing @@ -50,11 +57,12 @@ has_netCDF4, has_scipy, network, - raises_regex, requires_cfgrib, requires_cftime, requires_dask, + requires_fsspec, requires_h5netcdf, + requires_iris, requires_netCDF4, requires_pseudonetcdf, requires_pydap, @@ -79,12 +87,8 @@ try: import dask import dask.array as da - - dask_version = dask.__version__ except ImportError: - # needed for xfailed tests when dask < 2.4.0 - # remove when min dask > 2.4.0 - dask_version = "10.0" + pass ON_WINDOWS = sys.platform == "win32" default_value = object() @@ -734,7 +738,7 @@ def find_and_validate_array(obj): elif isinstance(obj.array, dask_array_type): assert isinstance(obj, indexing.DaskIndexingAdapter) elif isinstance(obj.array, pd.Index): - assert isinstance(obj, indexing.PandasIndexAdapter) + assert isinstance(obj, indexing.PandasIndexingAdapter) else: raise TypeError( "{} is wrapped by {}".format(type(obj.array), type(obj)) @@ -777,7 +781,7 @@ def test_dropna(self): assert_identical(expected, actual) def test_ondisk_after_print(self): - """ Make sure print does not load file into memory """ + """Make sure print does not load file into memory""" in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: repr(on_disk) @@ -857,6 +861,118 @@ def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn): assert decoded.variables[k].dtype == actual.variables[k].dtype assert_allclose(decoded, actual, decode_bytes=False) + @staticmethod + def _create_cf_dataset(): + original = Dataset( + dict( + variable=( + ("ln_p", "latitude", "longitude"), + np.arange(8, dtype="f4").reshape(2, 2, 2), + {"ancillary_variables": "std_devs det_lim"}, + ), + std_devs=( + ("ln_p", "latitude", "longitude"), + np.arange(0.1, 0.9, 0.1).reshape(2, 2, 2), + {"standard_name": "standard_error"}, + ), + det_lim=( + (), + 0.1, + {"standard_name": "detection_minimum"}, + ), + ), + dict( + latitude=("latitude", [0, 1], {"units": "degrees_north"}), + longitude=("longitude", [0, 1], {"units": "degrees_east"}), + latlon=((), -1, {"grid_mapping_name": "latitude_longitude"}), + latitude_bnds=(("latitude", "bnds2"), [[0, 1], [1, 2]]), + longitude_bnds=(("longitude", "bnds2"), [[0, 1], [1, 2]]), + areas=( + ("latitude", "longitude"), + [[1, 1], [1, 1]], + {"units": "degree^2"}, + ), + ln_p=( + "ln_p", + [1.0, 0.5], + { + "standard_name": "atmosphere_ln_pressure_coordinate", + "computed_standard_name": "air_pressure", + }, + ), + P0=((), 1013.25, {"units": "hPa"}), + ), + ) + original["variable"].encoding.update( + {"cell_measures": "area: areas", "grid_mapping": "latlon"}, + ) + original.coords["latitude"].encoding.update( + dict(grid_mapping="latlon", bounds="latitude_bnds") + ) + original.coords["longitude"].encoding.update( + dict(grid_mapping="latlon", bounds="longitude_bnds") + ) + original.coords["ln_p"].encoding.update({"formula_terms": "p0: P0 lev : ln_p"}) + return original + + def test_grid_mapping_and_bounds_are_not_coordinates_in_file(self): + original = self._create_cf_dataset() + with create_tmp_file() as tmp_file: + original.to_netcdf(tmp_file) + with open_dataset(tmp_file, decode_coords=False) as ds: + assert ds.coords["latitude"].attrs["bounds"] == "latitude_bnds" + assert ds.coords["longitude"].attrs["bounds"] == "longitude_bnds" + assert "coordinates" not in ds["variable"].attrs + assert "coordinates" not in ds.attrs + + def test_coordinate_variables_after_dataset_roundtrip(self): + original = self._create_cf_dataset() + with self.roundtrip(original, open_kwargs={"decode_coords": "all"}) as actual: + assert_identical(actual, original) + + with self.roundtrip(original) as actual: + expected = original.reset_coords( + ["latitude_bnds", "longitude_bnds", "areas", "P0", "latlon"] + ) + # equal checks that coords and data_vars are equal which + # should be enough + # identical would require resetting a number of attributes + # skip that. + assert_equal(actual, expected) + + def test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip(self): + original = self._create_cf_dataset() + # The DataArray roundtrip should have the same warnings as the + # Dataset, but we already tested for those, so just go for the + # new warnings. It would appear that there is no way to tell + # pytest "This warning and also this warning should both be + # present". + # xarray/tests/test_conventions.py::TestCFEncodedDataStore + # needs the to_dataset. The other backends should be fine + # without it. + with pytest.warns( + UserWarning, + match=( + r"Variable\(s\) referenced in bounds not in variables: " + r"\['l(at|ong)itude_bnds'\]" + ), + ): + with self.roundtrip( + original["variable"].to_dataset(), open_kwargs={"decode_coords": "all"} + ) as actual: + assert_identical(actual, original["variable"].to_dataset()) + + @requires_iris + def test_coordinate_variables_after_iris_roundtrip(self): + original = self._create_cf_dataset() + iris_cube = original["variable"].to_iris() + actual = DataArray.from_iris(iris_cube) + # Bounds will be missing (xfail) + del original.coords["latitude_bnds"], original.coords["longitude_bnds"] + # Ancillary vars will be missing + # Those are data_vars, and will be dropped when grabbing the variable + assert_identical(actual, original["variable"]) + def test_coordinates_encoding(self): def equals_latlon(obj): return obj == "lat lon" or obj == "lon lat" @@ -944,12 +1060,12 @@ def test_encoding_kwarg(self): assert ds.x.encoding == {} kwargs = dict(encoding={"x": {"foo": "bar"}}) - with raises_regex(ValueError, "unexpected encoding"): + with pytest.raises(ValueError, match=r"unexpected encoding"): with self.roundtrip(ds, save_kwargs=kwargs) as actual: pass kwargs = dict(encoding={"x": "foo"}) - with raises_regex(ValueError, "must be castable"): + with pytest.raises(ValueError, match=r"must be castable"): with self.roundtrip(ds, save_kwargs=kwargs) as actual: pass @@ -1056,8 +1172,8 @@ def test_append_with_invalid_dim_raises(self): self.save(data, tmp_file, mode="w") data["var9"] = data["var2"] * 3 data = data.isel(dim1=slice(2, 6)) # modify one dimension - with raises_regex( - ValueError, "Unable to update size for existing dimension" + with pytest.raises( + ValueError, match=r"Unable to update size for existing dimension" ): self.save(data, tmp_file, mode="a") @@ -1065,7 +1181,7 @@ def test_multiindex_not_implemented(self): ds = Dataset(coords={"y": ("x", [1, 2]), "z": ("x", ["a", "b"])}).set_index( x=["y", "z"] ) - with raises_regex(NotImplementedError, "MultiIndex"): + with pytest.raises(NotImplementedError, match=r"MultiIndex"): with self.roundtrip(ds): pass @@ -1124,7 +1240,7 @@ def test_open_group(self): # check that missing group raises appropriate exception with pytest.raises(IOError): open_dataset(tmp_file, group="bar") - with raises_regex(ValueError, "must be a string"): + with pytest.raises(ValueError, match=r"must be a string"): open_dataset(tmp_file, group=(1, 2, 3)) def test_open_subgroup(self): @@ -1558,6 +1674,9 @@ class ZarrBase(CFEncodedBase): DIMENSION_KEY = "_ARRAY_DIMENSIONS" + def create_zarr_target(self): + raise NotImplementedError + @contextlib.contextmanager def create_store(self): with self.create_zarr_target() as store_target: @@ -1584,8 +1703,8 @@ def roundtrip( with self.open(store_target, **open_kwargs) as ds: yield ds - def test_roundtrip_consolidated(self): - pytest.importorskip("zarr", minversion="2.2.1.dev2") + @pytest.mark.parametrize("consolidated", [False, True, None]) + def test_roundtrip_consolidated(self, consolidated): expected = create_test_data() with self.roundtrip( expected, @@ -1595,6 +1714,17 @@ def test_roundtrip_consolidated(self): self.check_dtypes_roundtripped(expected, actual) assert_identical(expected, actual) + def test_read_non_consolidated_warning(self): + expected = create_test_data() + with self.create_zarr_target() as store: + expected.to_zarr(store, consolidated=False) + with pytest.warns( + RuntimeWarning, + match="Failed to open Zarr store with consolidated", + ): + with xr.open_zarr(store) as ds: + assert_identical(ds, expected) + def test_with_chunkstore(self): expected = create_test_data() with self.create_zarr_target() as store_target, self.create_zarr_target() as chunk_store: @@ -1747,25 +1877,32 @@ def test_chunk_encoding_with_dask(self): # should fail if dask_chunks are irregular... ds_chunk_irreg = ds.chunk({"x": (5, 4, 3)}) - with raises_regex(ValueError, "uniform chunk sizes."): + with pytest.raises(ValueError, match=r"uniform chunk sizes."): with self.roundtrip(ds_chunk_irreg) as actual: pass # should fail if encoding["chunks"] clashes with dask_chunks badenc = ds.chunk({"x": 4}) badenc.var1.encoding["chunks"] = (6,) - with raises_regex(NotImplementedError, "named 'var1' would overlap"): + with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"): with self.roundtrip(badenc) as actual: pass + # unless... + with self.roundtrip(badenc, save_kwargs={"safe_chunks": False}) as actual: + # don't actually check equality because the data could be corrupted + pass + badenc.var1.encoding["chunks"] = (2,) - with raises_regex(ValueError, "Specified Zarr chunk encoding"): + with pytest.raises(NotImplementedError, match=r"Specified Zarr chunk encoding"): with self.roundtrip(badenc) as actual: pass badenc = badenc.chunk({"x": (3, 3, 6)}) badenc.var1.encoding["chunks"] = (3,) - with raises_regex(ValueError, "incompatible with this encoding"): + with pytest.raises( + NotImplementedError, match=r"incompatible with this encoding" + ): with self.roundtrip(badenc) as actual: pass @@ -1788,9 +1925,13 @@ def test_chunk_encoding_with_dask(self): # TODO: remove this failure once syncronized overlapping writes are # supported by xarray ds_chunk4["var1"].encoding.update({"chunks": 5}) - with pytest.raises(NotImplementedError): + with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"): with self.roundtrip(ds_chunk4) as actual: pass + # override option + with self.roundtrip(ds_chunk4, save_kwargs={"safe_chunks": False}) as actual: + # don't actually check equality because the data could be corrupted + pass def test_hidden_zarr_keys(self): expected = create_test_data() @@ -1816,7 +1957,6 @@ def test_hidden_zarr_keys(self): with xr.decode_cf(store): pass - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") @pytest.mark.parametrize("group", [None, "group1"]) def test_write_persistence_modes(self, group): original = create_test_data() @@ -1894,10 +2034,28 @@ def test_encoding_kwarg_fixed_width_string(self): def test_dataset_caching(self): super().test_dataset_caching() - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_write(self): super().test_append_write() + def test_append_with_mode_rplus_success(self): + original = Dataset({"foo": ("x", [1])}) + modified = Dataset({"foo": ("x", [2])}) + with self.create_zarr_target() as store: + original.to_zarr(store) + modified.to_zarr(store, mode="r+") + with self.open(store) as actual: + assert_identical(actual, modified) + + def test_append_with_mode_rplus_fails(self): + original = Dataset({"foo": ("x", [1])}) + modified = Dataset({"bar": ("x", [2])}) + with self.create_zarr_target() as store: + original.to_zarr(store) + with pytest.raises( + ValueError, match="dataset contains non-pre-existing variables" + ): + modified.to_zarr(store, mode="r+") + def test_append_with_invalid_dim_raises(self): ds, ds_to_append, _ = create_append_test_data() with self.create_zarr_target() as store_target: @@ -1958,7 +2116,6 @@ def test_check_encoding_is_consistent_after_append(self): xr.concat([ds, ds_to_append], dim="time"), ) - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_with_new_variable(self): ds, ds_to_append, ds_with_new_var = create_append_test_data() @@ -2054,31 +2211,64 @@ def test_write_region(self, consolidated, compute, use_dask): assert_identical(actual, zeros) for i in range(0, 10, 2): region = {"x": slice(i, i + 2)} - nonzeros.isel(region).to_zarr(store, region=region) + nonzeros.isel(region).to_zarr( + store, region=region, consolidated=consolidated + ) with xr.open_zarr(store, consolidated=consolidated) as actual: assert_identical(actual, nonzeros) + @pytest.mark.parametrize("mode", [None, "r+", "a"]) + def test_write_region_mode(self, mode): + zeros = Dataset({"u": (("x",), np.zeros(10))}) + nonzeros = Dataset({"u": (("x",), np.arange(1, 11))}) + with self.create_zarr_target() as store: + zeros.to_zarr(store) + for region in [{"x": slice(5)}, {"x": slice(5, 10)}]: + nonzeros.isel(region).to_zarr(store, region=region, mode=mode) + with xr.open_zarr(store) as actual: + assert_identical(actual, nonzeros) + @requires_dask - def test_write_region_metadata(self): - """Metadata should not be overwritten in "region" writes.""" - template = Dataset( - {"u": (("x",), np.zeros(10), {"variable": "template"})}, - attrs={"global": "template"}, + def test_write_preexisting_override_metadata(self): + """Metadata should be overriden if mode="a" but not in mode="r+".""" + original = Dataset( + {"u": (("x",), np.zeros(10), {"variable": "original"})}, + attrs={"global": "original"}, ) - data = Dataset( - {"u": (("x",), np.arange(1, 11), {"variable": "data"})}, - attrs={"global": "data"}, + both_modified = Dataset( + {"u": (("x",), np.ones(10), {"variable": "modified"})}, + attrs={"global": "modified"}, ) - expected = Dataset( - {"u": (("x",), np.arange(1, 11), {"variable": "template"})}, - attrs={"global": "template"}, + global_modified = Dataset( + {"u": (("x",), np.ones(10), {"variable": "original"})}, + attrs={"global": "modified"}, + ) + only_new_data = Dataset( + {"u": (("x",), np.ones(10), {"variable": "original"})}, + attrs={"global": "original"}, ) with self.create_zarr_target() as store: - template.to_zarr(store, compute=False) - data.to_zarr(store, region={"x": slice(None)}) + original.to_zarr(store, compute=False) + both_modified.to_zarr(store, mode="a") with self.open(store) as actual: - assert_identical(actual, expected) + # NOTE: this arguably incorrect -- we should probably be + # overriding the variable metadata, too. See the TODO note in + # ZarrStore.set_variables. + assert_identical(actual, global_modified) + + with self.create_zarr_target() as store: + original.to_zarr(store, compute=False) + both_modified.to_zarr(store, mode="r+") + with self.open(store) as actual: + assert_identical(actual, only_new_data) + + with self.create_zarr_target() as store: + original.to_zarr(store, compute=False) + # with region, the default mode becomes r+ + both_modified.to_zarr(store, region={"x": slice(None)}) + with self.open(store) as actual: + assert_identical(actual, only_new_data) def test_write_region_errors(self): data = Dataset({"u": (("x",), np.arange(5))}) @@ -2098,47 +2288,50 @@ def setup_and_verify_store(expected=data): data2.to_zarr(store, region={"x": slice(2)}) with setup_and_verify_store() as store: - with raises_regex(ValueError, "cannot use consolidated=True"): - data2.to_zarr(store, region={"x": slice(2)}, consolidated=True) - - with setup_and_verify_store() as store: - with raises_regex( - ValueError, "cannot set region unless mode='a' or mode=None" + with pytest.raises( + ValueError, + match=re.escape( + "cannot set region unless mode='a', mode='r+' or mode=None" + ), ): data.to_zarr(store, region={"x": slice(None)}, mode="w") with setup_and_verify_store() as store: - with raises_regex(TypeError, "must be a dict"): + with pytest.raises(TypeError, match=r"must be a dict"): data.to_zarr(store, region=slice(None)) with setup_and_verify_store() as store: - with raises_regex(TypeError, "must be slice objects"): + with pytest.raises(TypeError, match=r"must be slice objects"): data2.to_zarr(store, region={"x": [0, 1]}) with setup_and_verify_store() as store: - with raises_regex(ValueError, "step on all slices"): + with pytest.raises(ValueError, match=r"step on all slices"): data2.to_zarr(store, region={"x": slice(None, None, 2)}) with setup_and_verify_store() as store: - with raises_regex( - ValueError, "all keys in ``region`` are not in Dataset dimensions" + with pytest.raises( + ValueError, + match=r"all keys in ``region`` are not in Dataset dimensions", ): data.to_zarr(store, region={"y": slice(None)}) with setup_and_verify_store() as store: - with raises_regex( + with pytest.raises( ValueError, - "all variables in the dataset to write must have at least one dimension in common", + match=r"all variables in the dataset to write must have at least one dimension in common", ): data2.assign(v=2).to_zarr(store, region={"x": slice(2)}) with setup_and_verify_store() as store: - with raises_regex(ValueError, "cannot list the same dimension in both"): + with pytest.raises( + ValueError, match=r"cannot list the same dimension in both" + ): data.to_zarr(store, region={"x": slice(None)}, append_dim="x") with setup_and_verify_store() as store: - with raises_regex( - ValueError, "variable 'u' already exists with different dimension sizes" + with pytest.raises( + ValueError, + match=r"variable 'u' already exists with different dimension sizes", ): data2.to_zarr(store, region={"x": slice(3)}) @@ -2173,10 +2366,10 @@ def test_chunk_encoding_with_partial_dask_chunks(self): def test_open_zarr_use_cftime(self): ds = create_test_data() with self.create_zarr_target() as store_target: - ds.to_zarr(store_target, consolidated=True) - ds_a = xr.open_zarr(store_target, consolidated=True) + ds.to_zarr(store_target) + ds_a = xr.open_zarr(store_target) assert_identical(ds, ds_a) - ds_b = xr.open_zarr(store_target, consolidated=True, use_cftime=True) + ds_b = xr.open_zarr(store_target, use_cftime=True) assert xr.coding.times.contains_cftime_datetimes(ds_b.time) @@ -2261,7 +2454,7 @@ def create_store(self): def test_array_attrs(self): ds = Dataset(attrs={"foo": [[1, 2], [3, 4]]}) - with raises_regex(ValueError, "must be 1-dimensional"): + with pytest.raises(ValueError, match=r"must be 1-dimensional"): with self.roundtrip(ds): pass @@ -2282,7 +2475,7 @@ def test_nc4_scipy(self): with nc4.Dataset(tmp_file, "w", format="NETCDF4") as rootgrp: rootgrp.createGroup("foo") - with raises_regex(TypeError, "pip install netcdf4"): + with pytest.raises(TypeError, match=r"pip install netcdf4"): open_dataset(tmp_file, engine="scipy") @@ -2302,7 +2495,7 @@ def create_store(self): def test_encoding_kwarg_vlen_string(self): original = Dataset({"x": ["foo", "bar", "baz"]}) kwargs = dict(encoding={"x": {"dtype": str}}) - with raises_regex(ValueError, "encoding dtype=str for vlen"): + with pytest.raises(ValueError, match=r"encoding dtype=str for vlen"): with self.roundtrip(original, save_kwargs=kwargs): pass @@ -2334,18 +2527,18 @@ def test_write_store(self): @requires_scipy def test_engine(self): data = create_test_data() - with raises_regex(ValueError, "unrecognized engine"): + with pytest.raises(ValueError, match=r"unrecognized engine"): data.to_netcdf("foo.nc", engine="foobar") - with raises_regex(ValueError, "invalid engine"): + with pytest.raises(ValueError, match=r"invalid engine"): data.to_netcdf(engine="netcdf4") with create_tmp_file() as tmp_file: data.to_netcdf(tmp_file) - with raises_regex(ValueError, "unrecognized engine"): + with pytest.raises(ValueError, match=r"unrecognized engine"): open_dataset(tmp_file, engine="foobar") netcdf_bytes = data.to_netcdf() - with raises_regex(ValueError, "unrecognized engine"): + with pytest.raises(ValueError, match=r"unrecognized engine"): open_dataset(BytesIO(netcdf_bytes), engine="foobar") def test_cross_engine_read_write_netcdf3(self): @@ -2427,6 +2620,14 @@ def test_complex(self, invalid_netcdf, warntype, num_warns): assert recorded_num_warns == num_warns + def test_numpy_bool_(self): + # h5netcdf loads booleans as numpy.bool_, this type needs to be supported + # when writing invalid_netcdf datasets in order to support a roundtrip + expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})}) + save_kwargs = {"invalid_netcdf": True} + with self.roundtrip(expected, save_kwargs=save_kwargs) as actual: + assert_identical(expected, actual) + def test_cross_engine_read_write_netcdf4(self): # Drop dim3, because its labels include strings. These appear to be # not properly read with python-netCDF4, which converts them into @@ -2527,8 +2728,8 @@ def test_compression_check_encoding_h5py(self): # Incompatible encodings cause a crash with create_tmp_file() as tmp_file: - with raises_regex( - ValueError, "'zlib' and 'compression' encodings mismatch" + with pytest.raises( + ValueError, match=r"'zlib' and 'compression' encodings mismatch" ): data.to_netcdf( tmp_file, @@ -2537,8 +2738,9 @@ def test_compression_check_encoding_h5py(self): ) with create_tmp_file() as tmp_file: - with raises_regex( - ValueError, "'complevel' and 'compression_opts' encodings mismatch" + with pytest.raises( + ValueError, + match=r"'complevel' and 'compression_opts' encodings mismatch", ): data.to_netcdf( tmp_file, @@ -2568,6 +2770,7 @@ def test_dump_encodings_h5py(self): @requires_h5netcdf +@requires_netCDF4 class TestH5NetCDFAlreadyOpen: def test_open_dataset_group(self): import h5netcdf @@ -2578,13 +2781,19 @@ def test_open_dataset_group(self): v = group.createVariable("x", "int") v[...] = 42 - h5 = h5netcdf.File(tmp_file, mode="r") + kwargs = {} + if LooseVersion(h5netcdf.__version__) >= LooseVersion( + "0.10.0" + ) and LooseVersion(h5netcdf.core.h5py.__version__) >= LooseVersion("3.0.0"): + kwargs = dict(decode_vlen_strings=True) + + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) store = backends.H5NetCDFStore(h5["g"]) with open_dataset(store) as ds: expected = Dataset({"x": ((), 42)}) assert_identical(expected, ds) - h5 = h5netcdf.File(tmp_file, mode="r") + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) store = backends.H5NetCDFStore(h5, group="g") with open_dataset(store) as ds: expected = Dataset({"x": ((), 42)}) @@ -2599,7 +2808,13 @@ def test_deepcopy(self): v = nc.createVariable("y", np.int32, ("x",)) v[:] = np.arange(10) - h5 = h5netcdf.File(tmp_file, mode="r") + kwargs = {} + if LooseVersion(h5netcdf.__version__) >= LooseVersion( + "0.10.0" + ) and LooseVersion(h5netcdf.core.h5py.__version__) >= LooseVersion("3.0.0"): + kwargs = dict(decode_vlen_strings=True) + + h5 = h5netcdf.File(tmp_file, mode="r", **kwargs) store = backends.H5NetCDFStore(h5) with open_dataset(store) as ds: copied = ds.copy(deep=True) @@ -2612,23 +2827,27 @@ class TestH5NetCDFFileObject(TestH5NetCDFData): engine = "h5netcdf" def test_open_badbytes(self): - with raises_regex(ValueError, "HDF5 as bytes"): + with pytest.raises(ValueError, match=r"HDF5 as bytes"): with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"): pass - with raises_regex(ValueError, "cannot guess the engine"): + with pytest.raises( + ValueError, match=r"match in any of xarray's currently installed IO" + ): with open_dataset(b"garbage"): pass - with raises_regex(ValueError, "can only read bytes"): + with pytest.raises(ValueError, match=r"can only read bytes"): with open_dataset(b"garbage", engine="netcdf4"): pass - with raises_regex(ValueError, "not the signature of a valid netCDF file"): + with pytest.raises( + ValueError, match=r"not the signature of a valid netCDF4 file" + ): with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): pass def test_open_twice(self): expected = create_test_data() expected.attrs["foo"] = "bar" - with raises_regex(ValueError, "read/write pointer not at the start"): + with pytest.raises(ValueError, match=r"read/write pointer not at the start"): with create_tmp_file() as tmp_file: expected.to_netcdf(tmp_file, engine="h5netcdf") with open(tmp_file, "rb") as f: @@ -2636,6 +2855,7 @@ def test_open_twice(self): with open_dataset(f, engine="h5netcdf"): pass + @requires_scipy def test_open_fileobj(self): # open in-memory datasets instead of local file paths expected = create_test_data().drop_vars("dim3") @@ -2657,12 +2877,23 @@ def test_open_fileobj(self): assert_identical(expected, actual) f.seek(0) - with raises_regex(TypeError, "not a valid NetCDF 3"): + with pytest.raises(TypeError, match="not a valid NetCDF 3"): open_dataset(f, engine="scipy") + # TOOD: this additional open is required since scipy seems to close the file + # when it fails on the TypeError (though didn't when we used + # `raises_regex`?). Ref https://github.com/pydata/xarray/pull/5191 + with open(tmp_file, "rb") as f: f.seek(8) - with raises_regex(ValueError, "cannot guess the engine"): - open_dataset(f) + with pytest.raises( + ValueError, + match="match in any of xarray's currently installed IO", + ): + with pytest.warns( + RuntimeWarning, + match=re.escape("'h5netcdf' fails while guessing"), + ): + open_dataset(f) @requires_h5netcdf @@ -2857,19 +3088,75 @@ def gen_datasets_with_common_coord_and_time(self): return ds1, ds2 - @pytest.mark.parametrize("combine", ["nested", "by_coords"]) + @pytest.mark.parametrize( + "combine, concat_dim", [("nested", "t"), ("by_coords", None)] + ) @pytest.mark.parametrize("opt", ["all", "minimal", "different"]) @pytest.mark.parametrize("join", ["outer", "inner", "left", "right"]) - def test_open_mfdataset_does_same_as_concat(self, combine, opt, join): + def test_open_mfdataset_does_same_as_concat(self, combine, concat_dim, opt, join): with self.setup_files_and_datasets() as (files, [ds1, ds2]): if combine == "by_coords": files.reverse() with open_mfdataset( - files, data_vars=opt, combine=combine, concat_dim="t", join=join + files, data_vars=opt, combine=combine, concat_dim=concat_dim, join=join ) as ds: ds_expect = xr.concat([ds1, ds2], data_vars=opt, dim="t", join=join) assert_identical(ds, ds_expect) + @pytest.mark.parametrize( + ["combine_attrs", "attrs", "expected", "expect_error"], + ( + pytest.param("drop", [{"a": 1}, {"a": 2}], {}, False, id="drop"), + pytest.param( + "override", [{"a": 1}, {"a": 2}], {"a": 1}, False, id="override" + ), + pytest.param( + "no_conflicts", [{"a": 1}, {"a": 2}], None, True, id="no_conflicts" + ), + pytest.param( + "identical", + [{"a": 1, "b": 2}, {"a": 1, "c": 3}], + None, + True, + id="identical", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": -1, "c": 3}], + {"a": 1, "c": 3}, + False, + id="drop_conflicts", + ), + ), + ) + def test_open_mfdataset_dataset_combine_attrs( + self, combine_attrs, attrs, expected, expect_error + ): + with self.setup_files_and_datasets() as (files, [ds1, ds2]): + # Give the files an inconsistent attribute + for i, f in enumerate(files): + ds = open_dataset(f).load() + ds.attrs = attrs[i] + ds.close() + ds.to_netcdf(f) + + if expect_error: + with pytest.raises(xr.MergeError): + xr.open_mfdataset( + files, + combine="nested", + concat_dim="t", + combine_attrs=combine_attrs, + ) + else: + with xr.open_mfdataset( + files, + combine="nested", + concat_dim="t", + combine_attrs=combine_attrs, + ) as ds: + assert ds.attrs == expected + def test_open_mfdataset_dataset_attr_by_coords(self): """ Case when an attribute differs across the multiple files @@ -2882,7 +3169,7 @@ def test_open_mfdataset_dataset_attr_by_coords(self): ds.close() ds.to_netcdf(f) - with xr.open_mfdataset(files, combine="by_coords", concat_dim="t") as ds: + with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds: assert ds.test_dataset_attr == 10 def test_open_mfdataset_dataarray_attr_by_coords(self): @@ -2897,18 +3184,24 @@ def test_open_mfdataset_dataarray_attr_by_coords(self): ds.close() ds.to_netcdf(f) - with xr.open_mfdataset(files, combine="by_coords", concat_dim="t") as ds: + with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds: assert ds["v1"].test_dataarray_attr == 0 - @pytest.mark.parametrize("combine", ["nested", "by_coords"]) + @pytest.mark.parametrize( + "combine, concat_dim", [("nested", "t"), ("by_coords", None)] + ) @pytest.mark.parametrize("opt", ["all", "minimal", "different"]) - def test_open_mfdataset_exact_join_raises_error(self, combine, opt): + def test_open_mfdataset_exact_join_raises_error(self, combine, concat_dim, opt): with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]): if combine == "by_coords": files.reverse() - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): open_mfdataset( - files, data_vars=opt, combine=combine, concat_dim="t", join="exact" + files, + data_vars=opt, + combine=combine, + concat_dim=concat_dim, + join="exact", ) def test_common_coord_when_datavars_all(self): @@ -3038,12 +3331,19 @@ def test_open_mfdataset(self): ) as actual: assert actual.foo.variable.data.chunks == ((3, 2, 3, 2),) - with raises_regex(IOError, "no files to open"): + with pytest.raises(IOError, match=r"no files to open"): open_mfdataset("foo-bar-baz-*.nc") - - with raises_regex(ValueError, "wild-card"): + with pytest.raises(ValueError, match=r"wild-card"): open_mfdataset("http://some/remote/uri") + @requires_fsspec + def test_open_mfdataset_no_files(self): + pytest.importorskip("aiobotocore") + + # glob is attempted as of #4823, but finds no files + with pytest.raises(OSError, match=r"no files"): + open_mfdataset("http://some/remote/uri", engine="zarr") + def test_open_mfdataset_2d(self): original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))}) with create_tmp_file() as tmp1: @@ -3136,7 +3436,7 @@ def test_attrs_mfdataset(self): # first dataset loaded assert actual.test1 == ds1.test1 # attributes from ds2 are not retained, e.g., - with raises_regex(AttributeError, "no attribute"): + with pytest.raises(AttributeError, match=r"no attribute"): actual.test2 def test_open_mfdataset_attrs_file(self): @@ -3185,6 +3485,19 @@ def test_open_mfdataset_auto_combine(self): with open_mfdataset([tmp2, tmp1], combine="by_coords") as actual: assert_identical(original, actual) + # TODO check for an error instead of a warning once deprecated + def test_open_mfdataset_raise_on_bad_combine_args(self): + # Regression test for unhelpful error shown in #5230 + original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + original.isel(x=slice(5)).to_netcdf(tmp1) + original.isel(x=slice(5, 10)).to_netcdf(tmp2) + with pytest.warns( + DeprecationWarning, match="`concat_dim` has no effect" + ): + open_mfdataset([tmp1, tmp2], concat_dim="x") + @pytest.mark.xfail(reason="mfdataset loses encoding currently.") def test_encoding_mfdataset(self): original = Dataset( @@ -3235,15 +3548,15 @@ def test_save_mfdataset_roundtrip(self): def test_save_mfdataset_invalid(self): ds = Dataset() - with raises_regex(ValueError, "cannot use mode"): + with pytest.raises(ValueError, match=r"cannot use mode"): save_mfdataset([ds, ds], ["same", "same"]) - with raises_regex(ValueError, "same length"): + with pytest.raises(ValueError, match=r"same length"): save_mfdataset([ds, ds], ["only one path"]) def test_save_mfdataset_invalid_dataarray(self): # regression test for GH1555 da = DataArray([1, 2]) - with raises_regex(TypeError, "supports writing Dataset"): + with pytest.raises(TypeError, match=r"supports writing Dataset"): save_mfdataset([da], ["dataarray"]) def test_save_mfdataset_pathlib_roundtrip(self): @@ -3375,6 +3688,7 @@ def test_dataarray_compute(self): assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) + @pytest.mark.xfail def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed @@ -4382,7 +4696,7 @@ def test_rasterio_vrt_network(self): class TestEncodingInvalid: def test_extract_nc4_variable_encoding(self): var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"}) - with raises_regex(ValueError, "unexpected encoding"): + with pytest.raises(ValueError, match=r"unexpected encoding"): _extract_nc4_variable_encoding(var, raise_on_invalid=True) var = xr.Variable(("x",), [1, 2, 3], {}, {"chunking": (2, 1)}) @@ -4402,7 +4716,7 @@ def test_extract_nc4_variable_encoding(self): def test_extract_h5nc_encoding(self): # not supported with h5netcdf (yet) var = xr.Variable(("x",), [1, 2, 3], {}, {"least_sigificant_digit": 2}) - with raises_regex(ValueError, "unexpected encoding"): + with pytest.raises(ValueError, match=r"unexpected encoding"): _extract_nc4_variable_encoding(var, raise_on_invalid=True) @@ -4436,17 +4750,17 @@ def new_dataset_and_coord_attrs(): ds, attrs = new_dataset_and_attrs() attrs[123] = "test" - with raises_regex(TypeError, "Invalid name for attr: 123"): + with pytest.raises(TypeError, match=r"Invalid name for attr: 123"): ds.to_netcdf("test.nc") ds, attrs = new_dataset_and_attrs() attrs[MiscObject()] = "test" - with raises_regex(TypeError, "Invalid name for attr: "): + with pytest.raises(TypeError, match=r"Invalid name for attr: "): ds.to_netcdf("test.nc") ds, attrs = new_dataset_and_attrs() attrs[""] = "test" - with raises_regex(ValueError, "Invalid name for attr '':"): + with pytest.raises(ValueError, match=r"Invalid name for attr '':"): ds.to_netcdf("test.nc") # This one should work @@ -4457,12 +4771,12 @@ def new_dataset_and_coord_attrs(): ds, attrs = new_dataset_and_attrs() attrs["test"] = {"a": 5} - with raises_regex(TypeError, "Invalid value for attr 'test'"): + with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): ds.to_netcdf("test.nc") ds, attrs = new_dataset_and_attrs() attrs["test"] = MiscObject() - with raises_regex(TypeError, "Invalid value for attr 'test'"): + with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): ds.to_netcdf("test.nc") ds, attrs = new_dataset_and_attrs() @@ -4748,7 +5062,7 @@ def test_use_cftime_false_nonstandard_calendar(calendar, units_year): @pytest.mark.parametrize("engine", ["netcdf4", "scipy"]) def test_invalid_netcdf_raises(engine): data = create_test_data() - with raises_regex(ValueError, "unrecognized option 'invalid_netcdf'"): + with pytest.raises(ValueError, match=r"unrecognized option 'invalid_netcdf'"): data.to_netcdf("foo.nc", engine=engine, invalid_netcdf=True) @@ -4793,18 +5107,62 @@ def test_extract_zarr_variable_encoding(): # raises on invalid var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) - with raises_regex(ValueError, "unexpected encoding parameters"): + with pytest.raises(ValueError, match=r"unexpected encoding parameters"): actual = backends.zarr.extract_zarr_variable_encoding( var, raise_on_invalid=True ) +@requires_zarr +@requires_fsspec +@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") +def test_open_fsspec(): + import fsspec + import zarr + + if not hasattr(zarr.storage, "FSStore") or not hasattr( + zarr.storage.FSStore, "getitems" + ): + pytest.skip("zarr too old") + + ds = open_dataset(os.path.join(os.path.dirname(__file__), "data", "example_1.nc")) + + m = fsspec.filesystem("memory") + mm = m.get_mapper("out1.zarr") + ds.to_zarr(mm) # old interface + ds0 = ds.copy() + ds0["time"] = ds.time + pd.to_timedelta("1 day") + mm = m.get_mapper("out2.zarr") + ds0.to_zarr(mm) # old interface + + # single dataset + url = "memory://out2.zarr" + ds2 = open_dataset(url, engine="zarr") + assert ds0 == ds2 + + # single dataset with caching + url = "simplecache::memory://out2.zarr" + ds2 = open_dataset(url, engine="zarr") + assert ds0 == ds2 + + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + assert xr.concat([ds, ds0], dim="time") == ds2 + + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + assert xr.concat([ds, ds0], dim="time") == ds2 + + @requires_h5netcdf +@requires_netCDF4 def test_load_single_value_h5netcdf(tmp_path): """Test that numeric single-element vector attributes are handled fine. At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable - attributes as arrays of length 1, as oppesed to scalars for the NetCDF4 + attributes as arrays of length 1, as opposed to scalars for the NetCDF4 backend. This was leading to a ValueError upon loading a single value from a file, see #4471. Test that loading causes no failure. """ @@ -4881,3 +5239,86 @@ def test_chunking_consintency(chunks, tmp_path): with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual: xr.testing.assert_chunks_equal(actual, expected) + + +def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): + assert entrypoint.guess_can_open(obj) + with open_dataset(obj, engine=engine) as actual: + assert_identical(expected, actual) + + +@requires_netCDF4 +def test_netcdf4_entrypoint(tmp_path): + entrypoint = NetCDF4BackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, format="netcdf3_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + path = tmp_path / "bar" + ds.to_netcdf(path, format="netcdf4_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + assert entrypoint.guess_can_open("http://something/remote") + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + + path = tmp_path / "baz" + with open(path, "wb") as f: + f.write(b"not-a-netcdf-file") + assert not entrypoint.guess_can_open(path) + + +@requires_scipy +def test_scipy_entrypoint(tmp_path): + entrypoint = ScipyBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="scipy") + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) + + contents = ds.to_netcdf(engine="scipy") + _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) + _check_guess_can_open_and_open( + entrypoint, BytesIO(contents), engine="scipy", expected=ds + ) + + path = tmp_path / "foo.nc.gz" + with gzip.open(path, mode="wb") as f: + f.write(contents) + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc.gz") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") + + +@requires_h5netcdf +def test_h5netcdf_entrypoint(tmp_path): + entrypoint = H5netcdfBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds) + _check_guess_can_open_and_open( + entrypoint, str(path), engine="h5netcdf", expected=ds + ) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index d19f5aab585..4124d0d0b81 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -1,8 +1,9 @@ -import pytest +import numpy as np +import xarray as xr from xarray.backends.api import _get_default_engine -from . import requires_netCDF4, requires_scipy +from . import assert_identical, requires_netCDF4, requires_scipy @requires_netCDF4 @@ -14,8 +15,22 @@ def test__get_default_engine(): engine_gz = _get_default_engine("/example.gz") assert engine_gz == "scipy" - with pytest.raises(ValueError): - _get_default_engine("/example.grib") - engine_default = _get_default_engine("/example") assert engine_default == "netcdf4" + + +def test_custom_engine(): + expected = xr.Dataset( + dict(a=2 * np.arange(5)), coords=dict(x=("x", np.arange(5), dict(units="s"))) + ) + + class CustomBackend(xr.backends.BackendEntrypoint): + def open_dataset( + filename_or_obj, + drop_variables=None, + **kwargs, + ): + return expected.copy(deep=True) + + actual = xr.open_dataset("fake_filename", engine=CustomBackend) + assert_identical(expected, actual) diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index 3efcf8039c6..6d2d9907627 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -10,6 +10,8 @@ BaseCFTimeOffset, Day, Hour, + Microsecond, + Millisecond, Minute, MonthBegin, MonthEnd, @@ -24,22 +26,11 @@ to_cftime_datetime, to_offset, ) +from xarray.tests import _CFTIME_CALENDARS cftime = pytest.importorskip("cftime") -_CFTIME_CALENDARS = [ - "365_day", - "360_day", - "julian", - "all_leap", - "366_day", - "gregorian", - "proleptic_gregorian", - "standard", -] - - def _id_func(param): """Called on each parameter passed to pytest.mark.parametrize""" return str(param) @@ -181,6 +172,14 @@ def test_to_offset_offset_input(offset): ("2min", Minute(n=2)), ("S", Second()), ("2S", Second(n=2)), + ("L", Millisecond(n=1)), + ("2L", Millisecond(n=2)), + ("ms", Millisecond(n=1)), + ("2ms", Millisecond(n=2)), + ("U", Microsecond(n=1)), + ("2U", Microsecond(n=2)), + ("us", Microsecond(n=1)), + ("2us", Microsecond(n=2)), ], ids=_id_func, ) @@ -299,6 +298,8 @@ def test_to_cftime_datetime_error_type_error(): Hour(), Minute(), Second(), + Millisecond(), + Microsecond(), ] _EQ_TESTS_B = [ BaseCFTimeOffset(n=2), @@ -316,6 +317,8 @@ def test_to_cftime_datetime_error_type_error(): Hour(n=2), Minute(n=2), Second(n=2), + Millisecond(n=2), + Microsecond(n=2), ] @@ -340,6 +343,8 @@ def test_neq(a, b): Hour(n=2), Minute(n=2), Second(n=2), + Millisecond(n=2), + Microsecond(n=2), ] @@ -360,6 +365,8 @@ def test_eq(a, b): (Hour(), Hour(n=3)), (Minute(), Minute(n=3)), (Second(), Second(n=3)), + (Millisecond(), Millisecond(n=3)), + (Microsecond(), Microsecond(n=3)), ] @@ -387,6 +394,8 @@ def test_rmul(offset, expected): (Hour(), Hour(n=-1)), (Minute(), Minute(n=-1)), (Second(), Second(n=-1)), + (Millisecond(), Millisecond(n=-1)), + (Microsecond(), Microsecond(n=-1)), ], ids=_id_func, ) @@ -399,6 +408,8 @@ def test_neg(offset, expected): (Hour(n=2), (1, 1, 1, 2)), (Minute(n=2), (1, 1, 1, 0, 2)), (Second(n=2), (1, 1, 1, 0, 0, 2)), + (Millisecond(n=2), (1, 1, 1, 0, 0, 0, 2000)), + (Microsecond(n=2), (1, 1, 1, 0, 0, 0, 2)), ] @@ -427,6 +438,8 @@ def test_radd_sub_monthly(offset, expected_date_args, calendar): (Hour(n=2), (1, 1, 2, 22)), (Minute(n=2), (1, 1, 2, 23, 58)), (Second(n=2), (1, 1, 2, 23, 59, 58)), + (Millisecond(n=2), (1, 1, 2, 23, 59, 59, 998000)), + (Microsecond(n=2), (1, 1, 2, 23, 59, 59, 999998)), ], ids=_id_func, ) @@ -455,7 +468,7 @@ def test_minus_offset(a, b): @pytest.mark.parametrize( ("a", "b"), - list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) + list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) # type: ignore[arg-type] + [(YearEnd(month=1), YearEnd(month=2))], ids=_id_func, ) @@ -802,6 +815,8 @@ def test_add_quarter_end_onOffset( ((1, 1, 1), Hour(), True), ((1, 1, 1), Minute(), True), ((1, 1, 1), Second(), True), + ((1, 1, 1), Millisecond(), True), + ((1, 1, 1), Microsecond(), True), ], ids=_id_func, ) @@ -865,6 +880,8 @@ def test_onOffset_month_or_quarter_or_year_end( (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + (Millisecond(), (1, 3, 2, 1, 1, 1, 1000), (1, 3, 2, 1, 1, 1, 1000)), + (Microsecond(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), ], ids=_id_func, ) @@ -914,6 +931,8 @@ def test_rollforward(calendar, offset, initial_date_args, partial_expected_date_ (Hour(), (1, 3, 2, 1, 1), (1, 3, 2, 1, 1)), (Minute(), (1, 3, 2, 1, 1, 1), (1, 3, 2, 1, 1, 1)), (Second(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), + (Millisecond(), (1, 3, 2, 1, 1, 1, 1000), (1, 3, 2, 1, 1, 1, 1000)), + (Microsecond(), (1, 3, 2, 1, 1, 1, 1), (1, 3, 2, 1, 1, 1, 1)), ], ids=_id_func, ) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 71d6ffc8fff..725b5efee75 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -16,7 +16,7 @@ ) from xarray.tests import assert_array_equal, assert_identical -from . import raises_regex, requires_cftime, requires_cftime_1_1_0 +from . import requires_cftime from .test_coding_times import ( _ALL_CALENDARS, _NON_STANDARD_CALENDARS, @@ -244,7 +244,7 @@ def test_cftimeindex_dayofweek_accessor(index): assert_array_equal(result, expected) -@requires_cftime_1_1_0 +@requires_cftime def test_cftimeindex_days_in_month_accessor(index): result = index.days_in_month expected = [date.daysinmonth for date in index] @@ -340,7 +340,7 @@ def test_get_loc(date_type, index): result = index.get_loc("0001-02-01") assert result == slice(1, 2) - with raises_regex(KeyError, "1234"): + with pytest.raises(KeyError, match=r"1234"): index.get_loc("1234") @@ -438,7 +438,7 @@ def test_groupby(da): SEL_STRING_OR_LIST_TESTS = { "string": "0001", - "string-slice": slice("0001-01-01", "0001-12-30"), # type: ignore + "string-slice": slice("0001-01-01", "0001-12-30"), "bool-list": [True, True, False, False], } @@ -696,7 +696,7 @@ def test_concat_cftimeindex(date_type): ) da = xr.concat([da1, da2], dim="time") - assert isinstance(da.indexes["time"], CFTimeIndex) + assert isinstance(da.xindexes["time"].to_pandas_index(), CFTimeIndex) @requires_cftime @@ -916,7 +916,7 @@ def test_cftimeindex_calendar_property(calendar, expected): assert index.calendar == expected -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize( ("calendar", "expected"), [ @@ -936,7 +936,7 @@ def test_cftimeindex_calendar_repr(calendar, expected): assert "2000-01-01 00:00:00, 2000-01-02 00:00:00" in repr_str -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize("periods", [2, 40]) def test_cftimeindex_periods_repr(periods): """Test that cftimeindex has periods property in repr.""" @@ -945,7 +945,7 @@ def test_cftimeindex_periods_repr(periods): assert f" length={periods}" in repr_str -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize("calendar", ["noleap", "360_day", "standard"]) @pytest.mark.parametrize("freq", ["D", "H"]) def test_cftimeindex_freq_in_repr(freq, calendar): @@ -955,7 +955,7 @@ def test_cftimeindex_freq_in_repr(freq, calendar): assert f", freq='{freq}'" in repr_str -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize( "periods,expected", [ @@ -995,7 +995,7 @@ def test_cftimeindex_repr_formatting(periods, expected): assert expected == repr(index) -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize("display_width", [40, 80, 100]) @pytest.mark.parametrize("periods", [2, 3, 4, 100, 101]) def test_cftimeindex_repr_formatting_width(periods, display_width): @@ -1013,7 +1013,7 @@ def test_cftimeindex_repr_formatting_width(periods, display_width): assert s[:len_intro_str] == " " * len_intro_str -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize("periods", [22, 50, 100]) def test_cftimeindex_repr_101_shorter(periods): index_101 = xr.cftime_range(start="2000", periods=101) @@ -1187,7 +1187,7 @@ def test_asi8_distant_date(): np.testing.assert_array_equal(result, expected) -@requires_cftime_1_1_0 +@requires_cftime def test_infer_freq_valid_types(): cf_indx = xr.cftime_range("2000-01-01", periods=3, freq="D") assert xr.infer_freq(cf_indx) == "D" @@ -1202,7 +1202,7 @@ def test_infer_freq_valid_types(): assert xr.infer_freq(xr.DataArray(pd_td_indx)) == "D" -@requires_cftime_1_1_0 +@requires_cftime def test_infer_freq_invalid_inputs(): # Non-datetime DataArray with pytest.raises(ValueError, match="must contain datetime-like objects"): @@ -1231,7 +1231,7 @@ def test_infer_freq_invalid_inputs(): assert xr.infer_freq(indx[np.array([0, 1, 3])]) is None -@requires_cftime_1_1_0 +@requires_cftime @pytest.mark.parametrize( "freq", [ diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index c4f32795b59..526f3fc30c1 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -99,7 +99,10 @@ def test_resample(freqs, closed, label, base): ) .mean() ) - da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + da_cftime["time"] = ( + da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + ) xr.testing.assert_identical(da_cftime, da_datetime) @@ -145,5 +148,6 @@ def test_calendars(calendar): .resample(time=freq, closed=closed, label=label, base=base, loffset=loffset) .mean() ) - da_cftime["time"] = da_cftime.indexes["time"].to_datetimeindex() + # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass + da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py new file mode 100644 index 00000000000..278a961166f --- /dev/null +++ b/xarray/tests/test_coarsen.py @@ -0,0 +1,320 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset, set_options + +from . import ( + assert_allclose, + assert_equal, + assert_identical, + has_dask, + raise_if_dask_computes, + requires_cftime, +) +from .test_dataarray import da +from .test_dataset import ds + + +def test_coarsen_absent_dims_error(ds): + with pytest.raises(ValueError, match=r"not found in Dataset."): + ds.coarsen(foo=2) + + +@pytest.mark.parametrize("dask", [True, False]) +@pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) +def test_coarsen_dataset(ds, dask, boundary, side): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() + assert_equal( + actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() + ) + # coordinate should be mean by default + assert_equal( + actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() + ) + + +@pytest.mark.parametrize("dask", [True, False]) +def test_coarsen_coords(ds, dask): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + # check if coord_func works + actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() + assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) + assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) + + # raise if exact + with pytest.raises(ValueError): + ds.coarsen(x=3).mean() + # should be no error + ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() + + # working test with pd.time + da = xr.DataArray( + np.linspace(0, 365, num=364), + dims="time", + coords={"time": pd.date_range("15/12/1999", periods=364)}, + ) + actual = da.coarsen(time=2).mean() + + +@requires_cftime +def test_coarsen_coords_cftime(): + times = xr.cftime_range("2000", periods=6) + da = xr.DataArray(range(6), [("time", times)]) + actual = da.coarsen(time=3).mean() + expected_times = xr.cftime_range("2000-01-02", freq="3D", periods=2) + np.testing.assert_array_equal(actual.time, expected_times) + + +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ], +) +def test_coarsen_keep_attrs(funcname, argument): + global_attrs = {"units": "test", "long_name": "testing"} + da_attrs = {"da_attr": "test"} + attrs_coords = {"attrs_coords": "test"} + da_not_coarsend_attrs = {"da_not_coarsend_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + ds = Dataset( + data_vars={ + "da": ("coord", data, da_attrs), + "da_not_coarsend": ("no_coord", data, da_not_coarsend_attrs), + }, + coords={"coord": ("coord", coords, attrs_coords)}, + attrs=global_attrs, + ) + + # attrs are now kept per default + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_coarsend.attrs == da_not_coarsend_attrs + assert result.coord.attrs == attrs_coords + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # discard attrs + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # test discard attrs using global option + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + # keyword takes precedence over global option + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_coarsend.attrs == da_not_coarsend_attrs + assert result.coord.attrs == attrs_coords + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + func = getattr(ds.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_coarsend.attrs == {} + assert result.coord.attrs == {} + assert result.da.name == "da" + assert result.da_not_coarsend.name == "da_not_coarsend" + + +@pytest.mark.slow +@pytest.mark.parametrize("ds", (1, 2), indirect=True) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) +def test_coarsen_reduce(ds, window, name): + # Use boundary="trim" to accomodate all window sizes used in tests + coarsen_obj = ds.coarsen(time=window, boundary="trim") + + # add nan prefix to numpy methods to get similar behavior as bottleneck + actual = coarsen_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(coarsen_obj, name)() + assert_allclose(actual, expected) + + # make sure the order of data_var are not changed. + assert list(ds.data_vars.keys()) == list(actual.data_vars.keys()) + + # Make sure the dimension order is restored + for key, src_var in ds.data_vars.items(): + assert src_var.dims == actual[key].dims + + +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ], +) +def test_coarsen_da_keep_attrs(funcname, argument): + attrs_da = {"da_attr": "test"} + attrs_coords = {"attrs_coords": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + da = DataArray( + data, + dims=("coord"), + coords={"coord": ("coord", coords, attrs_coords)}, + attrs=attrs_da, + name="name", + ) + + # attrs are now kept per default + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == attrs_da + da.coord.attrs == attrs_coords + assert result.name == "name" + + # discard attrs + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + # test discard attrs using global option + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + # keyword takes precedence over global option + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + assert result.attrs == attrs_da + da.coord.attrs == {} + assert result.name == "name" + + func = getattr(da.coarsen(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + da.coord.attrs == {} + assert result.name == "name" + + +@pytest.mark.parametrize("da", (1, 2), indirect=True) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) +def test_coarsen_da_reduce(da, window, name): + if da.isnull().sum() > 1 and window == 1: + pytest.skip("These parameters lead to all-NaN slices") + + # Use boundary="trim" to accomodate all window sizes used in tests + coarsen_obj = da.coarsen(time=window, boundary="trim") + + # add nan prefix to numpy methods to get similar # behavior as bottleneck + actual = coarsen_obj.reduce(getattr(np, f"nan{name}")) + expected = getattr(coarsen_obj, name)() + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("dask", [True, False]) +def test_coarsen_construct(dask): + + ds = Dataset( + { + "vart": ("time", np.arange(48), {"a": "b"}), + "varx": ("x", np.arange(10), {"a": "b"}), + "vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}), + "vary": ("y", np.arange(12)), + }, + coords={"time": np.arange(48), "y": np.arange(12)}, + attrs={"foo": "bar"}, + ) + + if dask and has_dask: + ds = ds.chunk({"x": 4, "time": 10}) + + expected = xr.Dataset(attrs={"foo": "bar"}) + expected["vart"] = (("year", "month"), ds.vart.data.reshape((-1, 12)), {"a": "b"}) + expected["varx"] = (("x", "x_reshaped"), ds.varx.data.reshape((-1, 5)), {"a": "b"}) + expected["vartx"] = ( + ("x", "x_reshaped", "year", "month"), + ds.vartx.data.reshape(2, 5, 4, 12), + {"a": "b"}, + ) + expected["vary"] = ds.vary + expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12))) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")} + ) + assert_identical(actual, expected) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + time=("year", "month"), x=("x", "x_reshaped") + ) + assert_identical(actual, expected) + + with raise_if_dask_computes(): + actual = ds.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False + ) + for var in actual: + assert actual[var].attrs == {} + assert actual.attrs == {} + + with raise_if_dask_computes(): + actual = ds.vartx.coarsen(time=12, x=5).construct( + {"time": ("year", "month"), "x": ("x", "x_reshaped")} + ) + assert_identical(actual, expected["vartx"]) + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(foo="bar") + + with pytest.raises(ValueError): + ds.coarsen(time=12, x=2).construct(time=("year", "month")) + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct() + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(time="bar") + + with pytest.raises(ValueError): + ds.coarsen(time=12).construct(time=("bar",)) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index e0df7782aa7..839f2fd1f2e 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -117,3 +117,31 @@ def test_scaling_offset_as_list(scale_factor, add_offset): encoded = coder.encode(original) roundtripped = coder.decode(encoded) assert_allclose(original, roundtripped) + + +@pytest.mark.parametrize("bits", [1, 2, 4, 8]) +def test_decode_unsigned_from_signed(bits): + unsigned_dtype = np.dtype(f"u{bits}") + signed_dtype = np.dtype(f"i{bits}") + original_values = np.array([np.iinfo(unsigned_dtype).max], dtype=unsigned_dtype) + encoded = xr.Variable( + ("x",), original_values.astype(signed_dtype), attrs={"_Unsigned": "true"} + ) + coder = variables.UnsignedIntegerCoder() + decoded = coder.decode(encoded) + assert decoded.dtype == unsigned_dtype + assert decoded.values == original_values + + +@pytest.mark.parametrize("bits", [1, 2, 4, 8]) +def test_decode_signed_from_unsigned(bits): + unsigned_dtype = np.dtype(f"u{bits}") + signed_dtype = np.dtype(f"i{bits}") + original_values = np.array([-1], dtype=signed_dtype) + encoded = xr.Variable( + ("x",), original_values.astype(unsigned_dtype), attrs={"_Unsigned": "false"} + ) + coder = variables.UnsignedIntegerCoder() + decoded = coder.decode(encoded) + assert decoded.dtype == signed_dtype + assert decoded.values == original_values diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index c9d10ba4eb0..800e91d9473 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -7,13 +7,7 @@ from xarray.coding import strings from xarray.core import indexing -from . import ( - IndexerMaker, - assert_array_equal, - assert_identical, - raises_regex, - requires_dask, -) +from . import IndexerMaker, assert_array_equal, assert_identical, requires_dask with suppress(ImportError): import dask.array as da @@ -210,7 +204,7 @@ def test_char_to_bytes_dask(): assert actual.dtype == "S3" assert_array_equal(np.array(actual), expected) - with raises_regex(ValueError, "stacked dask character array"): + with pytest.raises(ValueError, match=r"stacked dask character array"): strings.char_to_bytes(array.rechunk(1)) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index dfd558f737e..f0882afe367 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,4 +1,5 @@ import warnings +from datetime import timedelta from itertools import product import numpy as np @@ -6,8 +7,17 @@ import pytest from pandas.errors import OutOfBoundsDatetime -from xarray import DataArray, Dataset, Variable, coding, conventions, decode_cf +from xarray import ( + DataArray, + Dataset, + Variable, + cftime_range, + coding, + conventions, + decode_cf, +) from xarray.coding.times import ( + _encode_datetime_with_cftime, cftime_to_nptime, decode_cf_datetime, encode_cf_datetime, @@ -16,9 +26,17 @@ from xarray.coding.variables import SerializationWarning from xarray.conventions import _update_bounds_attributes, cf_encoder from xarray.core.common import contains_cftime_datetimes -from xarray.testing import assert_equal - -from . import arm_xfail, assert_array_equal, has_cftime, requires_cftime, requires_dask +from xarray.testing import assert_equal, assert_identical + +from . import ( + arm_xfail, + assert_array_equal, + has_cftime, + has_cftime_1_4_1, + requires_cftime, + requires_cftime_1_4_1, + requires_dask, +) _NON_STANDARD_CALENDARS_SET = { "noleap", @@ -61,6 +79,9 @@ (0, "microseconds since 2000-01-01T00:00:00"), (np.int32(788961600), "seconds since 1981-01-01"), # GH2002 (12300 + np.arange(5), "hour since 1680-01-01 00:00:00.500000"), + (164375, "days since 1850-01-01 00:00:00"), + (164374.5, "days since 1850-01-01 00:00:00"), + ([164374.5, 168360.5], "days since 1850-01-01 00:00:00"), ] _CF_DATETIME_TESTS = [ num_dates_units + (calendar,) @@ -803,6 +824,15 @@ def test_encode_cf_datetime_overflow(shape): np.testing.assert_array_equal(dates, roundtrip) +def test_encode_expected_failures(): + + dates = pd.date_range("2000", periods=3) + with pytest.raises(ValueError, match="invalid time units"): + encode_cf_datetime(dates, units="days after 2000-01-01") + with pytest.raises(ValueError, match="invalid reference date"): + encode_cf_datetime(dates, units="days since NO_YEAR") + + def test_encode_cf_datetime_pandas_min(): # GH 2623 dates = pd.date_range("2000", periods=3) @@ -972,8 +1002,13 @@ def test_decode_ambiguous_time_warns(calendar): @pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) -def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq): - times = pd.date_range("2000", periods=3, freq=freq) +@pytest.mark.parametrize("date_range", [pd.date_range, cftime_range]) +def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq, date_range): + if not has_cftime_1_4_1 and date_range == cftime_range: + pytest.skip("Test requires cftime 1.4.1.") + if (freq == "N" or encoding_units == "nanoseconds") and date_range == cftime_range: + pytest.skip("Nanosecond frequency is not valid for cftime dates.") + times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" encoded, _, _ = coding.times.encode_cf_datetime(times, units) @@ -986,7 +1021,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq): @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) -def test_encode_decode_roundtrip(freq): +def test_encode_decode_roundtrip_datetime64(freq): # See GH 4045. Prior to GH 4684 this test would fail for frequencies of # "S", "L", "U", and "N". initial_time = pd.date_range("1678-01-01", periods=1) @@ -995,3 +1030,51 @@ def test_encode_decode_roundtrip(freq): encoded = conventions.encode_cf_variable(variable) decoded = conventions.decode_cf_variable("time", encoded) assert_equal(variable, decoded) + + +@requires_cftime_1_4_1 +@pytest.mark.parametrize("freq", ["U", "L", "S", "T", "H", "D"]) +def test_encode_decode_roundtrip_cftime(freq): + initial_time = cftime_range("0001", periods=1) + times = initial_time.append( + cftime_range("0001", periods=2, freq=freq) + timedelta(days=291000 * 365) + ) + variable = Variable(["time"], times) + encoded = conventions.encode_cf_variable(variable) + decoded = conventions.decode_cf_variable("time", encoded, use_cftime=True) + assert_equal(variable, decoded) + + +@requires_cftime +def test__encode_datetime_with_cftime(): + # See GH 4870. cftime versions > 1.4.0 required us to adapt the + # way _encode_datetime_with_cftime was written. + import cftime + + calendar = "gregorian" + times = cftime.num2date([0, 1], "hours since 2000-01-01", calendar) + + encoding_units = "days since 2000-01-01" + expected = cftime.date2num(times, encoding_units, calendar) + result = _encode_datetime_with_cftime(times, encoding_units, calendar) + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize("calendar", ["gregorian", "Gregorian", "GREGORIAN"]) +def test_decode_encode_roundtrip_with_non_lowercase_letters(calendar): + # See GH 5093. + times = [0, 1] + units = "days since 2000-01-01" + attrs = {"calendar": calendar, "units": units} + variable = Variable(["time"], times, attrs) + decoded = conventions.decode_cf_variable("time", variable) + encoded = conventions.encode_cf_variable(decoded) + + # Previously this would erroneously be an array of cftime.datetime + # objects. We check here that it is decoded properly to np.datetime64. + assert np.issubdtype(decoded.dtype, np.datetime64) + + # Use assert_identical to ensure that the calendar attribute maintained its + # original form throughout the roundtripping process, uppercase letters and + # all. + assert_identical(variable, encoded) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 109b78f05a9..3ca964b94e1 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -1,10 +1,18 @@ from datetime import datetime +from distutils.version import LooseVersion from itertools import product import numpy as np import pytest -from xarray import DataArray, Dataset, combine_by_coords, combine_nested, concat +from xarray import ( + DataArray, + Dataset, + MergeError, + combine_by_coords, + combine_nested, + concat, +) from xarray.core import dtypes from xarray.core.combine import ( _check_shape_tile_ids, @@ -15,7 +23,7 @@ _new_tile_id, ) -from . import assert_equal, assert_identical, raises_regex, requires_cftime +from . import assert_equal, assert_identical, requires_cftime from .test_dataset import create_test_data @@ -161,15 +169,15 @@ def test_2d(self): def test_no_dimension_coords(self): ds0 = Dataset({"foo": ("x", [0, 1])}) ds1 = Dataset({"foo": ("x", [2, 3])}) - with raises_regex(ValueError, "Could not find any dimension"): + with pytest.raises(ValueError, match=r"Could not find any dimension"): _infer_concat_order_from_coords([ds1, ds0]) def test_coord_not_monotonic(self): ds0 = Dataset({"x": [0, 1]}) ds1 = Dataset({"x": [3, 2]}) - with raises_regex( + with pytest.raises( ValueError, - "Coordinate variable x is neither monotonically increasing nor", + match=r"Coordinate variable x is neither monotonically increasing nor", ): _infer_concat_order_from_coords([ds1, ds0]) @@ -319,13 +327,17 @@ class TestCheckShapeTileIDs: def test_check_depths(self): ds = create_test_data(0) combined_tile_ids = {(0,): ds, (0, 1): ds} - with raises_regex(ValueError, "sub-lists do not have consistent depths"): + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent depths" + ): _check_shape_tile_ids(combined_tile_ids) def test_check_lengths(self): ds = create_test_data(0) combined_tile_ids = {(0, 0): ds, (0, 1): ds, (0, 2): ds, (1, 0): ds, (1, 1): ds} - with raises_regex(ValueError, "sub-lists do not have consistent lengths"): + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent lengths" + ): _check_shape_tile_ids(combined_tile_ids) @@ -379,7 +391,7 @@ def test_combine_nested_join(self, join, expected): def test_combine_nested_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): combine_nested(objs, concat_dim="x", join="exact") def test_empty_input(self): @@ -471,7 +483,8 @@ def test_concat_name_symmetry(self): assert_identical(x_first, y_first) def test_concat_one_dim_merge_another(self): - data = create_test_data() + data = create_test_data(add_attrs=False) + data1 = data.copy(deep=True) data2 = data.copy(deep=True) @@ -497,7 +510,7 @@ def test_auto_combine_2d(self): assert_equal(result, expected) def test_auto_combine_2d_combine_attrs_kwarg(self): - ds = create_test_data + ds = lambda x: create_test_data(x, add_attrs=False) partway1 = concat([ds(0), ds(3)], dim="dim1") partway2 = concat([ds(1), ds(4)], dim="dim1") @@ -518,6 +531,9 @@ def test_auto_combine_2d_combine_attrs_kwarg(self): } expected_dict["override"] = expected.copy(deep=True) expected_dict["override"].attrs = {"a": 1} + f = lambda attrs, context: attrs[0] + expected_dict[f] = expected.copy(deep=True) + expected_dict[f].attrs = f([{"a": 1}], None) datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] @@ -528,7 +544,7 @@ def test_auto_combine_2d_combine_attrs_kwarg(self): datasets[1][1].attrs = {"a": 1, "e": 5} datasets[1][2].attrs = {"a": 1, "f": 6} - with raises_regex(ValueError, "combine_attrs='identical'"): + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): result = combine_nested( datasets, concat_dim=["dim1", "dim2"], combine_attrs="identical" ) @@ -556,15 +572,19 @@ def test_invalid_hypercube_input(self): ds = create_test_data datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4)]] - with raises_regex(ValueError, "sub-lists do not have consistent lengths"): + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent lengths" + ): combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [[ds(3), ds(4)]]] - with raises_regex(ValueError, "sub-lists do not have consistent depths"): + with pytest.raises( + ValueError, match=r"sub-lists do not have consistent depths" + ): combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [ds(3), ds(4)]] - with raises_regex(ValueError, "concat_dims has length"): + with pytest.raises(ValueError, match=r"concat_dims has length"): combine_nested(datasets, concat_dim=["dim1"]) def test_merge_one_dim_concat_another(self): @@ -626,6 +646,47 @@ def test_combine_nested_fill_value(self, fill_value): actual = combine_nested(datasets, concat_dim="t", fill_value=fill_value) assert_identical(expected, actual) + def test_combine_nested_unnamed_data_arrays(self): + unnamed_array = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + + actual = combine_nested([unnamed_array], concat_dim="x") + expected = unnamed_array + assert_identical(expected, actual) + + unnamed_array1 = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + unnamed_array2 = DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_nested([unnamed_array1, unnamed_array2], concat_dim="x") + expected = DataArray( + data=[1.0, 2.0, 3.0, 4.0], coords={"x": [0, 1, 2, 3]}, dims="x" + ) + assert_identical(expected, actual) + + da1 = DataArray(data=[[0.0]], coords={"x": [0], "y": [0]}, dims=["x", "y"]) + da2 = DataArray(data=[[1.0]], coords={"x": [0], "y": [1]}, dims=["x", "y"]) + da3 = DataArray(data=[[2.0]], coords={"x": [1], "y": [0]}, dims=["x", "y"]) + da4 = DataArray(data=[[3.0]], coords={"x": [1], "y": [1]}, dims=["x", "y"]) + objs = [[da1, da2], [da3, da4]] + + expected = DataArray( + data=[[0.0, 1.0], [2.0, 3.0]], + coords={"x": [0, 1], "y": [0, 1]}, + dims=["x", "y"], + ) + actual = combine_nested(objs, concat_dim=["x", "y"]) + assert_identical(expected, actual) + + # TODO aijams - Determine if this test is appropriate. + def test_nested_combine_mixed_datasets_arrays(self): + objs = [ + DataArray([0, 1], dims=("x"), coords=({"x": [0, 1]})), + Dataset({"x": [2, 3]}), + ] + with pytest.raises( + ValueError, match=r"Can't combine datasets with unnamed arrays." + ): + combine_nested(objs, "x") + class TestCombineAuto: def test_combine_by_coords(self): @@ -657,15 +718,28 @@ def test_combine_by_coords(self): assert_equal(actual, expected) objs = [Dataset({"x": 0}), Dataset({"x": 1})] - with raises_regex(ValueError, "Could not find any dimension coordinates"): + with pytest.raises( + ValueError, match=r"Could not find any dimension coordinates" + ): combine_by_coords(objs) objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] - with raises_regex(ValueError, "Every dimension needs a coordinate"): + with pytest.raises(ValueError, match=r"Every dimension needs a coordinate"): combine_by_coords(objs) - def test_empty_input(self): - assert_identical(Dataset(), combine_by_coords([])) + def test_empty_input(self): + assert_identical(Dataset(), combine_by_coords([])) + + def test_combine_coords_mixed_datasets_arrays(self): + objs = [ + DataArray([0, 1], dims=("x"), coords=({"x": [0, 1]})), + Dataset({"x": [2, 3]}), + ] + with pytest.raises( + ValueError, + match=r"Can't automatically combine datasets with unnamed arrays.", + ): + combine_by_coords(objs) @pytest.mark.parametrize( "join, expected", @@ -683,7 +757,7 @@ def test_combine_coords_join(self, join, expected): def test_combine_coords_join_exact(self): objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [1], "y": [1]})] - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): combine_nested(objs, concat_dim="x", join="exact") @pytest.mark.parametrize( @@ -695,6 +769,10 @@ def test_combine_coords_join_exact(self): Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}), ), ("override", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1})), + ( + lambda attrs, context: attrs[1], + Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}), + ), ], ) def test_combine_coords_combine_attrs(self, combine_attrs, expected): @@ -709,7 +787,7 @@ def test_combine_coords_combine_attrs(self, combine_attrs, expected): if combine_attrs == "no_conflicts": objs[1].attrs["a"] = 2 - with raises_regex(ValueError, "combine_attrs='no_conflicts'"): + with pytest.raises(ValueError, match=r"combine_attrs='no_conflicts'"): actual = combine_nested( objs, concat_dim="x", join="outer", combine_attrs=combine_attrs ) @@ -727,11 +805,158 @@ def test_combine_coords_combine_attrs_identical(self): objs[1].attrs["b"] = 2 - with raises_regex(ValueError, "combine_attrs='identical'"): + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): actual = combine_nested( objs, concat_dim="x", join="outer", combine_attrs="identical" ) + def test_combine_nested_combine_attrs_drop_conflicts(self): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1, "b": 2, "c": 3}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1, "b": 0, "d": 3}), + ] + expected = Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "c": 3, "d": 3}) + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="drop_conflicts" + ) + assert_identical(expected, actual) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_nested_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + { + "a": ("x", [1, 2], attrs1), + "b": ("x", [3, -1], attrs1), + "x": ("x", [0, 1], attrs1), + } + ) + data2 = Dataset( + { + "a": ("x", [2, 3], attrs2), + "b": ("x", [-2, 1], attrs2), + "x": ("x", [2, 3], attrs2), + } + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "a": ("x", [1, 2, 2, 3], expected_attrs), + "b": ("x", [3, -1, -2, 1], expected_attrs), + }, + {"x": ("x", [0, 1, 2, 3], expected_attrs)}, + ) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ], + ) + def test_combine_by_coords_combine_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = Dataset( + {"x": ("a", [0], attrs1), "y": ("a", [0], attrs1), "a": ("a", [0], attrs1)} + ) + data2 = Dataset( + {"x": ("a", [1], attrs2), "y": ("a", [1], attrs2), "a": ("a", [1], attrs2)} + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + combine_by_coords([data1, data2], combine_attrs=combine_attrs) + else: + actual = combine_by_coords([data1, data2], combine_attrs=combine_attrs) + expected = Dataset( + { + "x": ("a", [0, 1], expected_attrs), + "y": ("a", [0, 1], expected_attrs), + "a": ("a", [0, 1], expected_attrs), + } + ) + + assert_identical(actual, expected) + def test_infer_order_from_coords(self): data = create_test_data() objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] @@ -797,8 +1022,9 @@ def test_combine_by_coords_no_concat(self): def test_check_for_impossible_ordering(self): ds0 = Dataset({"x": [0, 1, 5]}) ds1 = Dataset({"x": [2, 3]}) - with raises_regex( - ValueError, "does not have monotonic global indexes along dimension x" + with pytest.raises( + ValueError, + match=r"does not have monotonic global indexes along dimension x", ): combine_by_coords([ds1, ds0]) @@ -818,6 +1044,22 @@ def test_combine_by_coords_incomplete_hypercube(self): with pytest.raises(ValueError): combine_by_coords([x1, x2, x3], fill_value=None) + def test_combine_by_coords_unnamed_arrays(self): + unnamed_array = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + + actual = combine_by_coords([unnamed_array]) + expected = unnamed_array + assert_identical(expected, actual) + + unnamed_array1 = DataArray(data=[1.0, 2.0], coords={"x": [0, 1]}, dims="x") + unnamed_array2 = DataArray(data=[3.0, 4.0], coords={"x": [2, 3]}, dims="x") + + actual = combine_by_coords([unnamed_array1, unnamed_array2]) + expected = DataArray( + data=[1.0, 2.0, 3.0, 4.0], coords={"x": [0, 1, 2, 3]}, dims="x" + ) + assert_identical(expected, actual) + @requires_cftime def test_combine_by_coords_distant_cftime_dates(): @@ -854,5 +1096,22 @@ def test_combine_by_coords_raises_for_differing_calendars(): da_1 = DataArray([0], dims=["time"], coords=[time_1], name="a").to_dataset() da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() - with raises_regex(TypeError, r"cannot compare .* \(different calendars\)"): + if LooseVersion(cftime.__version__) >= LooseVersion("1.5"): + error_msg = "Cannot combine along dimension 'time' with mixed types." + else: + error_msg = r"cannot compare .* \(different calendars\)" + + with pytest.raises(TypeError, match=error_msg): + combine_by_coords([da_1, da_2]) + + +def test_combine_by_coords_raises_for_differing_types(): + + # str and byte cannot be compared + da_1 = DataArray([0], dims=["time"], coords=[["a"]], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[[b"b"]], name="a").to_dataset() + + with pytest.raises( + TypeError, match=r"Cannot combine along dimension 'time' with mixed types." + ): combine_by_coords([da_1, da_2]) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4890536a5d7..2439ea30b4b 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,7 +1,6 @@ import functools import operator import pickle -from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -21,14 +20,16 @@ result_name, unified_dim_sizes, ) +from xarray.core.pycompat import dask_version -from . import has_dask, raises_regex, requires_dask +from . import has_dask, raise_if_dask_computes, requires_dask dask = pytest.importorskip("dask") def assert_identical(a, b): - """ A version of this function which accepts numpy arrays """ + """A version of this function which accepts numpy arrays""" + __tracebackhide__ = True from xarray.testing import assert_identical as assert_identical_ if hasattr(a, "identical"): @@ -563,14 +564,409 @@ def add(a, b, keep_attrs): assert_identical(actual.x.attrs, a.x.attrs) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_variable(strategy, attrs, expected, error): + a = xr.Variable("x", [0, 1], attrs=attrs[0]) + b = xr.Variable("x", [0, 1], attrs=attrs[1]) + c = xr.Variable("x", [0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Variable("x", [0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error): + a = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[0]) + b = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[1]) + c = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.DataArray(dims="x", data=[0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "dim": lambda attrs, default: (attrs, default), + "coord": lambda attrs, default: (default, attrs), + }.get(variant) + + dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[0]), "u": ("x", [0, 1], coord_attrs[0])}, + ) + b = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[1]), "u": ("x", [0, 1], coord_attrs[1])}, + ) + c = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[2]), "u": ("x", [0, 1], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.DataArray( + dims="x", + data=[0, 3], + coords={"x": ("x", [0, 1], dim_attrs), "u": ("x", [0, 1], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset(strategy, attrs, expected, error): + a = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[0]) + b = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[1]) + c = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Dataset({"a": ("x", [0, 3])}, attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("data", "dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "data": lambda attrs, default: (attrs, default, default), + "dim": lambda attrs, default: (default, attrs, default), + "coord": lambda attrs, default: (default, default, attrs), + }.get(variant) + data_attrs, dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.Dataset( + {"a": ("x", [], data_attrs[0])}, + coords={"x": ("x", [], dim_attrs[0]), "u": ("x", [], coord_attrs[0])}, + ) + b = xr.Dataset( + {"a": ("x", [], data_attrs[1])}, + coords={"x": ("x", [], dim_attrs[1]), "u": ("x", [], coord_attrs[1])}, + ) + c = xr.Dataset( + {"a": ("x", [], data_attrs[2])}, + coords={"x": ("x", [], dim_attrs[2]), "u": ("x", [], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + data_attrs, dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.Dataset( + {"a": ("x", [], data_attrs)}, + coords={"x": ("x", [], dim_attrs), "u": ("x", [], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + def test_dataset_join(): ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) # by default, cannot have different labels - with raises_regex(ValueError, "indexes .* are not equal"): + with pytest.raises(ValueError, match=r"indexes .* are not equal"): apply_ufunc(operator.add, ds0, ds1) - with raises_regex(TypeError, "must supply"): + with pytest.raises(TypeError, match=r"must supply"): apply_ufunc(operator.add, ds0, ds1, dataset_join="outer") def add(a, b, join, dataset_join): @@ -590,7 +986,7 @@ def add(a, b, join, dataset_join): actual = add(ds0, ds1, "outer", "outer") assert_identical(actual, expected) - with raises_regex(ValueError, "data variable names"): + with pytest.raises(ValueError, match=r"data variable names"): apply_ufunc(operator.add, ds0, xr.Dataset({"b": 1})) ds2 = xr.Dataset({"b": ("x", [99, 3]), "x": [1, 2]}) @@ -708,11 +1104,11 @@ def test_apply_dask_parallelized_errors(): data_array = xr.DataArray(array, dims=("x", "y")) # from apply_array_ufunc - with raises_regex(ValueError, "at least one input is an xarray object"): + with pytest.raises(ValueError, match=r"at least one input is an xarray object"): apply_ufunc(identity, array, dask="parallelized") # formerly from _apply_blockwise, now from apply_variable_ufunc - with raises_regex(ValueError, "consists of multiple chunks"): + with pytest.raises(ValueError, match=r"consists of multiple chunks"): apply_ufunc( identity, data_array, @@ -910,7 +1306,10 @@ def test_vectorize_dask_dtype_without_output_dtypes(data_array): assert expected.dtype == actual.dtype -@pytest.mark.xfail(LooseVersion(dask.__version__) < "2.3", reason="dask GH5274") +@pytest.mark.skipif( + dask_version > "2021.06", + reason="dask/dask#7669: can no longer pass output_dtypes and meta", +) @requires_dask def test_vectorize_dask_dtype_meta(): # meta dtype takes precedence @@ -990,6 +1389,7 @@ def arrays_w_tuples(): da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(), xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]), xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]), + xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]), ] array_tuples = [ @@ -998,12 +1398,40 @@ def arrays_w_tuples(): (arrays[1], arrays[1]), (arrays[2], arrays[2]), (arrays[2], arrays[3]), + (arrays[2], arrays[4]), + (arrays[4], arrays[2]), (arrays[3], arrays[3]), + (arrays[4], arrays[4]), ] return arrays, array_tuples +@pytest.mark.parametrize("ddof", [0, 1]) +@pytest.mark.parametrize( + "da_a, da_b", + [ + arrays_w_tuples()[1][3], + arrays_w_tuples()[1][4], + arrays_w_tuples()[1][5], + arrays_w_tuples()[1][6], + arrays_w_tuples()[1][7], + arrays_w_tuples()[1][8], + ], +) +@pytest.mark.parametrize("dim", [None, "x", "time"]) +def test_lazy_corrcov(da_a, da_b, dim, ddof): + # GH 5284 + from dask import is_dask_collection + + with raise_if_dask_computes(): + cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof) + assert is_dask_collection(cov) + + corr = xr.corr(da_a.chunk(), da_b.chunk(), dim=dim) + assert is_dask_collection(corr) + + @pytest.mark.parametrize("ddof", [0, 1]) @pytest.mark.parametrize( "da_a, da_b", @@ -1160,7 +1588,9 @@ def test_vectorize_dask_new_output_dims(): ).transpose(*expected.dims) assert_identical(expected, actual) - with raises_regex(ValueError, "dimension 'z1' in 'output_sizes' must correspond"): + with pytest.raises( + ValueError, match=r"dimension 'z1' in 'output_sizes' must correspond" + ): apply_ufunc( func, data_array.chunk({"x": 1}), @@ -1171,8 +1601,8 @@ def test_vectorize_dask_new_output_dims(): dask_gufunc_kwargs=dict(output_sizes={"z1": 1}), ) - with raises_regex( - ValueError, "dimension 'z' in 'output_core_dims' needs corresponding" + with pytest.raises( + ValueError, match=r"dimension 'z' in 'output_core_dims' needs corresponding" ): apply_ufunc( func, @@ -1193,10 +1623,10 @@ def identity(x): def tuple3x(x): return (x, x, x) - with raises_regex(ValueError, "number of outputs"): + with pytest.raises(ValueError, match=r"number of outputs"): apply_ufunc(identity, variable, output_core_dims=[(), ()]) - with raises_regex(ValueError, "number of outputs"): + with pytest.raises(ValueError, match=r"number of outputs"): apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) @@ -1209,13 +1639,13 @@ def add_dim(x): def remove_dim(x): return x[..., 0] - with raises_regex(ValueError, "unexpected number of dimensions"): + with pytest.raises(ValueError, match=r"unexpected number of dimensions"): apply_ufunc(add_dim, variable, output_core_dims=[("y", "z")]) - with raises_regex(ValueError, "unexpected number of dimensions"): + with pytest.raises(ValueError, match=r"unexpected number of dimensions"): apply_ufunc(add_dim, variable) - with raises_regex(ValueError, "unexpected number of dimensions"): + with pytest.raises(ValueError, match=r"unexpected number of dimensions"): apply_ufunc(remove_dim, variable) @@ -1231,11 +1661,11 @@ def truncate(array): def apply_truncate_broadcast_invalid(obj): return apply_ufunc(truncate, obj) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_broadcast_invalid(variable) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_broadcast_invalid(data_array) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_broadcast_invalid(dataset) def apply_truncate_x_x_invalid(obj): @@ -1243,11 +1673,11 @@ def apply_truncate_x_x_invalid(obj): truncate, obj, input_core_dims=[["x"]], output_core_dims=[["x"]] ) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_x_x_invalid(variable) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_x_x_invalid(data_array) - with raises_regex(ValueError, "size of dimension"): + with pytest.raises(ValueError, match=r"size of dimension"): apply_truncate_x_x_invalid(dataset) def apply_truncate_x_z(obj): @@ -1442,7 +1872,7 @@ def test_dot_align_coords(use_dask): xr.testing.assert_allclose(expected, actual) with xr.set_options(arithmetic_join="exact"): - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): xr.dot(da_a, da_b) # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 7416cab13ed..36ef0237b27 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -12,7 +12,6 @@ assert_array_equal, assert_equal, assert_identical, - raises_regex, requires_dask, ) from .test_dataset import create_test_data @@ -41,9 +40,11 @@ def test_concat_compat(): for var in ["has_x", "no_x_y"]: assert "y" not in result[var].dims and "y" not in result[var].coords - with raises_regex(ValueError, "coordinates in some datasets but not others"): + with pytest.raises( + ValueError, match=r"coordinates in some datasets but not others" + ): concat([ds1, ds2], dim="q") - with raises_regex(ValueError, "'q' is not present in all datasets"): + with pytest.raises(ValueError, match=r"'q' is not present in all datasets"): concat([ds2, ds1], dim="q") @@ -141,7 +142,7 @@ def test_concat_coords(self): actual = concat(objs, dim="x", coords=coords) assert_identical(expected, actual) for coords in ["minimal", []]: - with raises_regex(merge.MergeError, "conflicting values"): + with pytest.raises(merge.MergeError, match="conflicting values"): concat(objs, dim="x", coords=coords) def test_concat_constant_index(self): @@ -152,7 +153,7 @@ def test_concat_constant_index(self): for mode in ["different", "all", ["foo"]]: actual = concat([ds1, ds2], "y", data_vars=mode) assert_identical(expected, actual) - with raises_regex(merge.MergeError, "conflicting values"): + with pytest.raises(merge.MergeError, match="conflicting values"): # previously dim="y", and raised error which makes no sense. # "foo" has dimension "y" so minimal should concatenate it? concat([ds1, ds2], "new_dim", data_vars="minimal") @@ -185,36 +186,40 @@ def test_concat_errors(self): data = create_test_data() split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] - with raises_regex(ValueError, "must supply at least one"): + with pytest.raises(ValueError, match=r"must supply at least one"): concat([], "dim1") - with raises_regex(ValueError, "Cannot specify both .*='different'"): + with pytest.raises(ValueError, match=r"Cannot specify both .*='different'"): concat( [data, data], dim="concat_dim", data_vars="different", compat="override" ) - with raises_regex(ValueError, "must supply at least one"): + with pytest.raises(ValueError, match=r"must supply at least one"): concat([], "dim1") - with raises_regex(ValueError, "are not coordinates"): + with pytest.raises(ValueError, match=r"are not coordinates"): concat([data, data], "new_dim", coords=["not_found"]) - with raises_regex(ValueError, "global attributes not"): + with pytest.raises(ValueError, match=r"global attributes not"): data0, data1 = deepcopy(split_data) data1.attrs["foo"] = "bar" concat([data0, data1], "dim1", compat="identical") assert_identical(data, concat([data0, data1], "dim1", compat="equals")) - with raises_regex(ValueError, "compat.* invalid"): + with pytest.raises(ValueError, match=r"compat.* invalid"): concat(split_data, "dim1", compat="foobar") - with raises_regex(ValueError, "unexpected value for"): + with pytest.raises(ValueError, match=r"unexpected value for"): concat([data, data], "new_dim", coords="foobar") - with raises_regex(ValueError, "coordinate in some datasets but not others"): + with pytest.raises( + ValueError, match=r"coordinate in some datasets but not others" + ): concat([Dataset({"x": 0}), Dataset({"x": [1]})], dim="z") - with raises_regex(ValueError, "coordinate in some datasets but not others"): + with pytest.raises( + ValueError, match=r"coordinate in some datasets but not others" + ): concat([Dataset({"x": 0}), Dataset({}, {"x": 1})], dim="z") def test_concat_join_kwarg(self): @@ -242,7 +247,7 @@ def test_concat_join_kwarg(self): coords={"x": [0, 1], "y": [0]}, ) - with raises_regex(ValueError, "indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -258,27 +263,131 @@ def test_concat_join_kwarg(self): ) assert_identical(actual, expected) - def test_concat_combine_attrs_kwarg(self): - ds1 = Dataset({"a": ("x", [0])}, coords={"x": [0]}, attrs={"b": 42}) - ds2 = Dataset({"a": ("x", [0])}, coords={"x": [1]}, attrs={"b": 42, "c": 43}) - - expected = {} - expected["drop"] = Dataset({"a": ("x", [0, 0])}, {"x": [0, 1]}) - expected["no_conflicts"] = Dataset( - {"a": ("x", [0, 0])}, {"x": [0, 1]}, {"b": 42, "c": 43} - ) - expected["override"] = Dataset({"a": ("x", [0, 0])}, {"x": [0, 1]}, {"b": 42}) - - with raises_regex(ValueError, "combine_attrs='identical'"): - actual = concat([ds1, ds2], dim="x", combine_attrs="identical") - with raises_regex(ValueError, "combine_attrs='no_conflicts'"): - ds3 = ds2.copy(deep=True) - ds3.attrs["b"] = 44 - actual = concat([ds1, ds3], dim="x", combine_attrs="no_conflicts") + @pytest.mark.parametrize( + "combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": 41, "c": 43, "d": 44}, + False, + ), + ( + lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {}, + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": -1, "b": 0, "c": 1}, + False, + ), + ], + ) + def test_concat_combine_attrs_kwarg( + self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception + ): + ds1 = Dataset({"a": ("x", [0])}, coords={"x": [0]}, attrs=var1_attrs) + ds2 = Dataset({"a": ("x", [0])}, coords={"x": [1]}, attrs=var2_attrs) + + if expect_exception: + with pytest.raises(ValueError, match=f"combine_attrs='{combine_attrs}'"): + concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + else: + actual = concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + expected = Dataset( + {"a": ("x", [0, 0])}, {"x": [0, 1]}, attrs=expected_attrs + ) - for combine_attrs in expected: + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": 41, "c": 43, "d": 44}, + False, + ), + ( + lambda attrs, context: {"a": -1, "b": 0, "c": 1} if any(attrs) else {}, + {"a": 41, "b": 42, "c": 43}, + {"b": 2, "c": 43, "d": 44}, + {"a": -1, "b": 0, "c": 1}, + False, + ), + ], + ) + def test_concat_combine_attrs_kwarg_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + ds1 = Dataset({"a": ("x", [0], attrs1)}, coords={"x": ("x", [0], attrs1)}) + ds2 = Dataset({"a": ("x", [0], attrs2)}, coords={"x": ("x", [1], attrs2)}) + + if expect_exception: + with pytest.raises(ValueError, match=f"combine_attrs='{combine_attrs}'"): + concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + else: actual = concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) - assert_identical(actual, expected[combine_attrs]) + expected = Dataset( + {"a": ("x", [0, 0], expected_attrs)}, + {"x": ("x", [0, 1], expected_attrs)}, + ) + + assert_identical(actual, expected) def test_concat_promote_shape(self): # mixed dims within variables @@ -426,7 +535,7 @@ def test_concat(self): stacked = concat(grouped, ds["x"]) assert_identical(foo, stacked) # with an index as the 'dim' argument - stacked = concat(grouped, ds.indexes["x"]) + stacked = concat(grouped, pd.Index(ds["x"], name="x")) assert_identical(foo, stacked) actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) @@ -437,10 +546,10 @@ def test_concat(self): expected = foo[:2].rename({"x": "concat_dim"}) assert_identical(expected, actual) - with raises_regex(ValueError, "not identical"): + with pytest.raises(ValueError, match=r"not identical"): concat([foo, bar], dim="w", compat="identical") - with raises_regex(ValueError, "not a valid argument"): + with pytest.raises(ValueError, match=r"not a valid argument"): concat([foo, bar], dim="w", data_vars="minimal") def test_concat_encoding(self): @@ -518,7 +627,7 @@ def test_concat_join_kwarg(self): coords={"x": [0, 1], "y": [0]}, ) - with raises_regex(ValueError, "indexes along dimension 'y'"): + with pytest.raises(ValueError, match=r"indexes along dimension 'y'"): actual = concat([ds1, ds2], join="exact", dim="x") for join in expected: @@ -538,9 +647,9 @@ def test_concat_combine_attrs_kwarg(self): [0, 0], coords=[("x", [0, 1])], attrs={"b": 42} ) - with raises_regex(ValueError, "combine_attrs='identical'"): + with pytest.raises(ValueError, match=r"combine_attrs='identical'"): actual = concat([da1, da2], dim="x", combine_attrs="identical") - with raises_regex(ValueError, "combine_attrs='no_conflicts'"): + with pytest.raises(ValueError, match=r"combine_attrs='no_conflicts'"): da3 = da2.copy(deep=True) da3.attrs["b"] = 44 actual = concat([da1, da3], dim="x", combine_attrs="no_conflicts") @@ -593,14 +702,14 @@ def test_concat_merge_single_non_dim_coord(): actual = concat([da1, da2], "x", coords=coords) assert_identical(actual, expected) - with raises_regex(ValueError, "'y' is not present in all datasets."): + with pytest.raises(ValueError, match=r"'y' is not present in all datasets."): concat([da1, da2], dim="x", coords="all") da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]}) da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) for coords in ["different", "all"]: - with raises_regex(ValueError, "'y' not present in all datasets"): + with pytest.raises(ValueError, match=r"'y' not present in all datasets"): concat([da1, da2, da3], dim="x") @@ -635,3 +744,20 @@ def test_concat_preserve_coordinate_order(): for act, exp in zip(actual.coords, expected.coords): assert act == exp assert_identical(actual.coords[act], expected.coords[exp]) + + +def test_concat_typing_check(): + ds = Dataset({"foo": 1}, {"bar": 2}) + da = Dataset({"foo": 3}, {"bar": 4}).to_array(dim="foo") + + # concatenate a list of non-homogeneous types must raise TypeError + with pytest.raises( + TypeError, + match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", + ): + concat([ds, da], dim="foo") + with pytest.raises( + TypeError, + match="The elements in the input list need to be either all 'Dataset's or all 'DataArray's", + ): + concat([da, ds], dim="foo") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 9abaa978651..ceea167719f 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -18,13 +18,7 @@ from xarray.conventions import decode_cf from xarray.testing import assert_identical -from . import ( - assert_array_equal, - raises_regex, - requires_cftime, - requires_dask, - requires_netCDF4, -) +from . import assert_array_equal, requires_cftime, requires_dask, requires_netCDF4 from .test_backends import CFEncodedBase @@ -145,9 +139,43 @@ def test_do_not_overwrite_user_coordinates(self): assert enc["a"].attrs["coordinates"] == "y" assert enc["b"].attrs["coordinates"] == "z" orig["a"].attrs["coordinates"] = "foo" - with raises_regex(ValueError, "'coordinates' found in both attrs"): + with pytest.raises(ValueError, match=r"'coordinates' found in both attrs"): conventions.encode_dataset_coordinates(orig) + def test_emit_coordinates_attribute_in_attrs(self): + orig = Dataset( + {"a": 1, "b": 1}, + coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, + ) + + orig["a"].attrs["coordinates"] = None + enc, _ = conventions.encode_dataset_coordinates(orig) + + # check coordinate attribute emitted for 'a' + assert "coordinates" not in enc["a"].attrs + assert "coordinates" not in enc["a"].encoding + + # check coordinate attribute not emitted for 'b' + assert enc["b"].attrs.get("coordinates") == "t" + assert "coordinates" not in enc["b"].encoding + + def test_emit_coordinates_attribute_in_encoding(self): + orig = Dataset( + {"a": 1, "b": 1}, + coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, + ) + + orig["a"].encoding["coordinates"] = None + enc, _ = conventions.encode_dataset_coordinates(orig) + + # check coordinate attribute emitted for 'a' + assert "coordinates" not in enc["a"].attrs + assert "coordinates" not in enc["a"].encoding + + # check coordinate attribute not emitted for 'b' + assert enc["b"].attrs.get("coordinates") == "t" + assert "coordinates" not in enc["b"].encoding + @requires_dask def test_string_object_warning(self): original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() @@ -236,7 +264,7 @@ def test_decode_cf_with_drop_variables(self): @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") def test_invalid_time_units_raises_eagerly(self): ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})}) - with raises_regex(ValueError, "unable to decode time"): + with pytest.raises(ValueError, match=r"unable to decode time"): decode_cf(ds) @requires_cftime @@ -286,7 +314,7 @@ def test_decode_cf_with_dask(self): assert all( isinstance(var.data, da.Array) for name, var in decoded.variables.items() - if name not in decoded.indexes + if name not in decoded.xindexes ) assert_identical(decoded, conventions.decode_cf(original).compute()) diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 0276b8ebc08..69f43d99139 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -47,7 +47,7 @@ def test_cupy_import(): def test_check_data_stays_on_gpu(toy_weather_data): """Perform some operations and check the data stays on the GPU.""" freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") - assert isinstance(freeze.data, cp.core.core.ndarray) + assert isinstance(freeze.data, cp.ndarray) def test_where(): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 19a61c60577..d5d460056aa 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -2,7 +2,6 @@ import pickle import sys from contextlib import suppress -from distutils.version import LooseVersion from textwrap import dedent import numpy as np @@ -13,6 +12,7 @@ import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops +from xarray.core.pycompat import dask_version from xarray.testing import assert_chunks_equal from xarray.tests import mock @@ -24,7 +24,6 @@ assert_frame_equal, assert_identical, raise_if_dask_computes, - raises_regex, requires_pint_0_15, requires_scipy_or_netCDF4, ) @@ -39,7 +38,7 @@ def test_raise_if_dask_computes(): data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) - with raises_regex(RuntimeError, "Too many computes"): + with pytest.raises(RuntimeError, match=r"Too many computes"): with raise_if_dask_computes(): data.compute() @@ -95,7 +94,7 @@ def test_copy(self): def test_chunk(self): for chunks, expected in [ - (None, ((2, 2), (2, 2, 2))), + ({}, ((2, 2), (2, 2, 2))), (3, ((3, 1), (3, 3))), ({"x": 3, "y": 3}, ((3, 1), (3, 3))), ({"x": 3}, ((3, 1), (2, 2, 2))), @@ -111,7 +110,30 @@ def test_indexing(self): self.assertLazyAndIdentical(u[0], v[0]) self.assertLazyAndIdentical(u[:1], v[:1]) self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]]) - with raises_regex(TypeError, "stored in a dask array"): + + @pytest.mark.skipif(dask_version < "2021.04.1", reason="Requires dask >= 2021.04.1") + @pytest.mark.parametrize( + "expected_data, index", + [ + (da.array([99, 2, 3, 4]), 0), + (da.array([99, 99, 99, 4]), slice(2, None, -1)), + (da.array([99, 99, 3, 99]), [0, -1, 1]), + (da.array([99, 99, 99, 4]), np.arange(3)), + (da.array([1, 99, 99, 99]), [False, True, True, True]), + (da.array([1, 99, 99, 99]), np.arange(4) > 0), + (da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0), + ], + ) + def test_setitem_dask_array(self, expected_data, index): + arr = Variable(("x"), da.array([1, 2, 3, 4])) + expected = Variable(("x"), expected_data) + arr[index] = 99 + assert_identical(arr, expected) + + @pytest.mark.skipif(dask_version >= "2021.04.1", reason="Requires dask < 2021.04.1") + def test_setitem_dask_array_error(self): + with pytest.raises(TypeError, match=r"stored in a dask array"): + v = self.lazy_var v[:1] = 0 def test_squeeze(self): @@ -194,9 +216,9 @@ def test_reduce(self): self.assertLazyAndAllClose(u.argmin(dim="x"), actual) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) - with raises_regex(NotImplementedError, "only works along an axis"): + with pytest.raises(NotImplementedError, match=r"only works along an axis"): v.median() - with raises_regex(NotImplementedError, "only works along an axis"): + with pytest.raises(NotImplementedError, match=r"only works along an axis"): v.median(v.dims) with raise_if_dask_computes(): v.reduce(duck_array_ops.mean) @@ -506,7 +528,7 @@ def test_groupby_first(self): for coords in [u.coords, v.coords]: coords["ab"] = ("x", ["a", "a", "b", "b"]) - with raises_regex(NotImplementedError, "dask"): + with pytest.raises(NotImplementedError, match=r"dask"): v.groupby("ab").first() expected = u.groupby("ab").first() with raise_if_dask_computes(): @@ -584,25 +606,6 @@ def test_dot(self): lazy = self.lazy_array.dot(self.lazy_array[0]) self.assertLazyAndAllClose(eager, lazy) - @pytest.mark.skipif(LooseVersion(dask.__version__) >= "2.0", reason="no meta") - def test_dataarray_repr_legacy(self): - data = build_dask_array("data") - nonindex_coord = build_dask_array("coord") - a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) - expected = dedent( - """\ - - {!r} - Coordinates: - y (x) int64 dask.array - Dimensions without coordinates: x""".format( - data - ) - ) - assert expected == repr(a) - assert kernel_call_count == 0 # should not evaluate dask array - - @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta") def test_dataarray_repr(self): data = build_dask_array("data") nonindex_coord = build_dask_array("coord") @@ -620,7 +623,6 @@ def test_dataarray_repr(self): assert expected == repr(a) assert kernel_call_count == 0 # should not evaluate dask array - @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta") def test_dataset_repr(self): data = build_dask_array("data") nonindex_coord = build_dask_array("coord") @@ -851,7 +853,7 @@ def test_to_dask_dataframe_dim_order(self): assert isinstance(actual, dd.DataFrame) assert_frame_equal(expected, actual.compute()) - with raises_regex(ValueError, "does not match the set of dimensions"): + with pytest.raises(ValueError, match=r"does not match the set of dimensions"): ds.to_dask_dataframe(dim_order=["x"]) @@ -1038,15 +1040,29 @@ def test_unify_chunks(map_ds): ds_copy = map_ds.copy() ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) - with raises_regex(ValueError, "inconsistent chunks"): + with pytest.raises(ValueError, match=r"inconsistent chunks"): ds_copy.chunks - expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)} + expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)} with raise_if_dask_computes(): actual_chunks = ds_copy.unify_chunks().chunks - expected_chunks == actual_chunks + assert actual_chunks == expected_chunks assert_identical(map_ds, ds_copy.unify_chunks()) + out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy")) + assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5)) + assert out_b.chunks == expected_chunks + + # Test unordered dims + da = ds_copy["cxy"] + out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1})) + assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5)) + assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2)) + + # Test mismatch + with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"): + xr.unify_chunks(da, da.isel(x=slice(2))) + @pytest.mark.parametrize("obj", [make_ds(), make_da()]) @pytest.mark.parametrize( @@ -1070,34 +1086,34 @@ def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "Received dimension 'x' of length 1"): + with pytest.raises(ValueError, match=r"Received dimension 'x' of length 1"): xr.map_blocks(bad_func, map_da).compute() def returns_numpy(darray): return (darray * darray.x + 5 * darray.y).values - with raises_regex(TypeError, "Function must return an xarray DataArray"): + with pytest.raises(TypeError, match=r"Function must return an xarray DataArray"): xr.map_blocks(returns_numpy, map_da) - with raises_regex(TypeError, "args must be"): + with pytest.raises(TypeError, match=r"args must be"): xr.map_blocks(operator.add, map_da, args=10) - with raises_regex(TypeError, "kwargs must be"): + with pytest.raises(TypeError, match=r"kwargs must be"): xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20]) def really_bad_func(darray): raise ValueError("couldn't do anything.") - with raises_regex(Exception, "Cannot infer"): + with pytest.raises(Exception, match=r"Cannot infer"): xr.map_blocks(really_bad_func, map_da) ds_copy = map_ds.copy() ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10}) - with raises_regex(ValueError, "inconsistent chunks"): + with pytest.raises(ValueError, match=r"inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) - with raises_regex(TypeError, "Cannot pass dask collections"): + with pytest.raises(TypeError, match=r"Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) @@ -1152,10 +1168,10 @@ def sumda(da1, da2): mapped = xr.map_blocks(operator.add, da1, args=[da2]) xr.testing.assert_equal(da1 + da2, mapped) - with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): + with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"): xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) - with raises_regex(ValueError, "indexes along dimension 'x' are not equal"): + with pytest.raises(ValueError, match=r"indexes along dimension 'x' are not equal"): xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) # reduction @@ -1233,7 +1249,7 @@ def test_map_blocks_to_array(map_ds): lambda x: x.drop_vars("x"), lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.expand_dims(k=3), - lambda x: x.assign_coords(new_coord=("y", x.y * 2)), + lambda x: x.assign_coords(new_coord=("y", x.y.data * 2)), lambda x: x.astype(np.int32), lambda x: x.x, ], @@ -1296,21 +1312,21 @@ def test_map_blocks_template_convert_object(): @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_errors_bad_template(obj): - with raises_regex(ValueError, "unexpected coordinate variables"): + with pytest.raises(ValueError, match=r"unexpected coordinate variables"): xr.map_blocks(lambda x: x.assign_coords(a=10), obj, template=obj).compute() - with raises_regex(ValueError, "does not contain coordinate variables"): + with pytest.raises(ValueError, match=r"does not contain coordinate variables"): xr.map_blocks(lambda x: x.drop_vars("cxy"), obj, template=obj).compute() - with raises_regex(ValueError, "Dimensions {'x'} missing"): + with pytest.raises(ValueError, match=r"Dimensions {'x'} missing"): xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute() - with raises_regex(ValueError, "Received dimension 'x' of length 1"): + with pytest.raises(ValueError, match=r"Received dimension 'x' of length 1"): xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() - with raises_regex(TypeError, "must be a DataArray"): + with pytest.raises(TypeError, match=r"must be a DataArray"): xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=(obj,)).compute() - with raises_regex(ValueError, "map_blocks requires that one block"): + with pytest.raises(ValueError, match=r"map_blocks requires that one block"): xr.map_blocks( lambda x: x.isel(x=[1]).assign_coords(x=10), obj, template=obj.isel(x=[1]) ).compute() - with raises_regex(ValueError, "Expected index 'x' to be"): + with pytest.raises(ValueError, match=r"Expected index 'x' to be"): xr.map_blocks( lambda a: a.isel(x=[1]).assign_coords(x=[120]), # assign bad index values obj, @@ -1319,7 +1335,7 @@ def test_map_blocks_errors_bad_template(obj): def test_map_blocks_errors_bad_template_2(map_ds): - with raises_regex(ValueError, "unexpected data variables {'xyz'}"): + with pytest.raises(ValueError, match=r"unexpected data variables {'xyz'}"): xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute() @@ -1577,7 +1593,7 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy) assert_equal(map_ds.map(lambda x: x), map_ds) assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds) - assert_equal(map_ds.update({"a": map_ds.a}), map_ds) + assert_equal(map_ds.assign({"a": map_ds.a}), map_ds) # fails because of index error # assert_equal( @@ -1599,3 +1615,38 @@ def test_optimize(): arr = xr.DataArray(a).chunk(5) (arr2,) = dask.optimize(arr) arr2.compute() + + +# The graph_manipulation module is in dask since 2021.2 but it became usable with +# xarray only since 2021.3 +@pytest.mark.skipif(dask_version <= "2021.02.0", reason="new module") +def test_graph_manipulation(): + """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder + function returned by __dask_postperist__; also, the dsk passed to the rebuilder is + a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict. + """ + import dask.graph_manipulation as gm + + v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2 + da = DataArray(v) + ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])}) + + v2, da2, ds2 = gm.clone(v, da, ds) + + assert_equal(v2, v) + assert_equal(da2, da) + assert_equal(ds2, ds) + + for a, b in ((v, v2), (da, da2), (ds, ds2)): + assert a.__dask_layers__() != b.__dask_layers__() + assert len(a.__dask_layers__()) == len(b.__dask_layers__()) + assert a.__dask_graph__().keys() != b.__dask_graph__().keys() + assert len(a.__dask_graph__()) == len(b.__dask_graph__()) + assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys() + assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers) + + # Above we performed a slice operation; adding the two slices back together creates + # a diamond-shaped dependency graph, which in turn will trigger a collision in layer + # names if we were to use HighLevelGraph.cull() instead of + # HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__(). + assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fc84687511e..8ab8bc872da 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -7,6 +7,8 @@ import numpy as np import pandas as pd import pytest +from pandas.core.computation.ops import UndefinedVariableError +from pandas.tseries.frequencies import to_offset import xarray as xr from xarray import ( @@ -22,7 +24,7 @@ from xarray.convert import from_cdms2 from xarray.core import dtypes from xarray.core.common import full_like -from xarray.core.indexes import propagate_indexes +from xarray.core.indexes import Index, PandasIndex, propagate_indexes from xarray.core.utils import is_scalar from xarray.tests import ( LooseVersion, @@ -33,11 +35,13 @@ assert_identical, has_dask, raise_if_dask_computes, - raises_regex, requires_bottleneck, + requires_cupy, requires_dask, requires_iris, requires_numbagg, + requires_numexpr, + requires_pint_0_15, requires_scipy, requires_sparse, source_ndarray, @@ -129,7 +133,7 @@ def test_properties(self): with pytest.raises(AttributeError): self.dv.dataset assert isinstance(self.ds["x"].to_index(), pd.Index) - with raises_regex(ValueError, "must be 1-dimensional"): + with pytest.raises(ValueError, match=r"must be 1-dimensional"): self.ds["foo"].to_index() with pytest.raises(AttributeError): self.dv.variable = self.v @@ -145,10 +149,17 @@ def test_data_property(self): def test_indexes(self): array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])]) - expected = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} - assert array.indexes.keys() == expected.keys() - for k in expected: - assert array.indexes[k].equals(expected[k]) + expected_indexes = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} + expected_xindexes = { + k: PandasIndex(idx, k) for k, idx in expected_indexes.items() + } + assert array.xindexes.keys() == expected_xindexes.keys() + assert array.indexes.keys() == expected_indexes.keys() + assert all([isinstance(idx, pd.Index) for idx in array.indexes.values()]) + assert all([isinstance(idx, Index) for idx in array.xindexes.values()]) + for k in expected_indexes: + assert array.xindexes[k].equals(expected_xindexes[k]) + assert array.indexes[k].equals(expected_indexes[k]) def test_get_index(self): array = DataArray(np.zeros((2, 3)), coords={"x": ["a", "b"]}, dims=["x", "y"]) @@ -240,7 +251,7 @@ def test_dims(self): arr = self.dv assert arr.dims == ("x", "y") - with raises_regex(AttributeError, "you cannot assign"): + with pytest.raises(AttributeError, match=r"you cannot assign"): arr.dims = ("w", "z") def test_sizes(self): @@ -297,6 +308,9 @@ def test_constructor(self): actual = DataArray(data, coords, ["x", "y"]) assert_identical(expected, actual) + actual = DataArray(data, coords) + assert_identical(expected, actual) + coords = [("x", ["a", "b"]), ("y", [-1, -2, -3])] actual = DataArray(data, coords) assert_identical(expected, actual) @@ -325,31 +339,35 @@ def test_constructor(self): expected = Dataset({None: (["x", "y"], data, {}, {"bar": 2})})[None] assert_identical(expected, actual) + actual = DataArray([1, 2, 3], coords={"x": [0, 1, 2]}) + expected = DataArray([1, 2, 3], coords=[("x", [0, 1, 2])]) + assert_identical(expected, actual) + def test_constructor_invalid(self): data = np.random.randn(3, 2) - with raises_regex(ValueError, "coords is not dict-like"): + with pytest.raises(ValueError, match=r"coords is not dict-like"): DataArray(data, [[0, 1, 2]], ["x", "y"]) - with raises_regex(ValueError, "not a subset of the .* dim"): + with pytest.raises(ValueError, match=r"not a subset of the .* dim"): DataArray(data, {"x": [0, 1, 2]}, ["a", "b"]) - with raises_regex(ValueError, "not a subset of the .* dim"): + with pytest.raises(ValueError, match=r"not a subset of the .* dim"): DataArray(data, {"x": [0, 1, 2]}) - with raises_regex(TypeError, "is not a string"): + with pytest.raises(TypeError, match=r"is not a string"): DataArray(data, dims=["x", None]) - with raises_regex(ValueError, "conflicting sizes for dim"): + with pytest.raises(ValueError, match=r"conflicting sizes for dim"): DataArray([1, 2, 3], coords=[("x", [0, 1])]) - with raises_regex(ValueError, "conflicting sizes for dim"): + with pytest.raises(ValueError, match=r"conflicting sizes for dim"): DataArray([1, 2], coords={"x": [0, 1], "y": ("x", [1])}, dims="x") - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): DataArray(np.random.rand(4, 4), [("x", self.mindex), ("y", self.mindex)]) - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) - with raises_regex(ValueError, "matching the dimension size"): + with pytest.raises(ValueError, match=r"matching the dimension size"): DataArray(data, coords={"x": 0}, dims=["x", "y"]) def test_constructor_from_self_described(self): @@ -689,7 +707,7 @@ def get_data(): da = get_data() # indexer with inconsistent coordinates. ind = DataArray(np.arange(1, 4), dims=["x"], coords={"x": np.random.randn(3)}) - with raises_regex(IndexError, "dimension coordinate 'x'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): da[dict(x=ind)] = 0 # indexer with consistent coordinates. @@ -706,7 +724,7 @@ def get_data(): dims=["x", "y", "z"], coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, ) - with raises_regex(IndexError, "dimension coordinate 'x'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): da[dict(x=ind)] = value # consistent coordinate in the assigning values @@ -734,7 +752,7 @@ def get_data(): dims=["x", "y", "z"], coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, ) - with raises_regex(IndexError, "dimension coordinate 'x'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): da[dict(x=ind)] = value # consistent coordinate in the assigning values @@ -795,9 +813,9 @@ def test_isel(self): assert_identical(self.dv, self.dv.isel(x=slice(None))) assert_identical(self.dv[:3], self.dv.isel(x=slice(3))) assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5))) - with raises_regex( + with pytest.raises( ValueError, - r"Dimensions {'not_a_dim'} do not exist. Expected " + match=r"Dimensions {'not_a_dim'} do not exist. Expected " r"one or more of \('x', 'y'\)", ): self.dv.isel(not_a_dim=0) @@ -868,7 +886,7 @@ def test_isel_fancy(self): ) # make sure we're raising errors in the right places - with raises_regex(IndexError, "Dimensions of indexers mismatch"): + with pytest.raises(IndexError, match=r"Dimensions of indexers mismatch"): da.isel(y=(("points",), [1, 2]), x=(("points",), [1, 2, 3])) # tests using index or DataArray as indexers @@ -882,7 +900,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"], stations["station"]) - with raises_regex(ValueError, "conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values for "): da.isel( x=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 2]}), y=DataArray([0, 1, 2], dims="station", coords={"station": [0, 1, 3]}), @@ -947,7 +965,7 @@ def test_sel_dataarray(self): def test_sel_invalid_slice(self): array = DataArray(np.arange(10), [("x", np.arange(10))]) - with raises_regex(ValueError, "cannot use non-scalar arrays"): + with pytest.raises(ValueError, match=r"cannot use non-scalar arrays"): array.sel(x=slice(array.x)) def test_sel_dataarray_datetime_slice(self): @@ -1040,11 +1058,11 @@ def test_head(self): assert_equal( self.dv.isel({dim: slice(5) for dim in self.dv.dims}), self.dv.head() ) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): self.dv.head([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): self.dv.head(x=3.1) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): self.dv.head(-3) def test_tail(self): @@ -1057,11 +1075,11 @@ def test_tail(self): assert_equal( self.dv.isel({dim: slice(-5, None) for dim in self.dv.dims}), self.dv.tail() ) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): self.dv.tail([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): self.dv.tail(x=3.1) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): self.dv.tail(-3) def test_thin(self): @@ -1070,13 +1088,13 @@ def test_thin(self): self.dv.isel({dim: slice(None, None, 6) for dim in self.dv.dims}), self.dv.thin(6), ) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): self.dv.thin([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): self.dv.thin(x=3.1) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): self.dv.thin(-3) - with raises_regex(ValueError, "cannot be zero"): + with pytest.raises(ValueError, match=r"cannot be zero"): self.dv.thin(time=0) def test_loc(self): @@ -1134,7 +1152,7 @@ def get_data(): da = get_data() # indexer with inconsistent coordinates. ind = DataArray(np.arange(1, 4), dims=["y"], coords={"y": np.random.randn(3)}) - with raises_regex(IndexError, "dimension coordinate 'y'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'y'"): da.loc[dict(x=ind)] = 0 # indexer with consistent coordinates. @@ -1151,7 +1169,7 @@ def get_data(): dims=["x", "y", "z"], coords={"x": [0, 1, 2], "non-dim": ("x", [0, 2, 4])}, ) - with raises_regex(IndexError, "dimension coordinate 'x'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'x'"): da.loc[dict(x=ind)] = value # consistent coordinate in the assigning values @@ -1315,14 +1333,14 @@ def test_coords(self): expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo") assert_identical(da, expected) - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): self.mda["level_1"] = np.arange(4) self.mda.coords["level_1"] = np.arange(4) def test_coords_to_index(self): da = DataArray(np.zeros((2, 3)), [("x", [1, 2]), ("y", list("abc"))]) - with raises_regex(ValueError, "no valid index"): + with pytest.raises(ValueError, match=r"no valid index"): da[0, 0].coords.to_index() expected = pd.Index(["a", "b", "c"], name="y") @@ -1341,7 +1359,7 @@ def test_coords_to_index(self): actual = da.coords.to_index(["y", "x"]) assert expected.equals(actual) - with raises_regex(ValueError, "ordered_dims must match"): + with pytest.raises(ValueError, match=r"ordered_dims must match"): da.coords.to_index(["x"]) def test_coord_coords(self): @@ -1415,11 +1433,11 @@ def test_reset_coords(self): ) assert_identical(actual, expected) - with raises_regex(ValueError, "cannot be found"): + with pytest.raises(ValueError, match=r"cannot be found"): data.reset_coords("foo", drop=True) - with raises_regex(ValueError, "cannot be found"): + with pytest.raises(ValueError, match=r"cannot be found"): data.reset_coords("not_found") - with raises_regex(ValueError, "cannot remove index"): + with pytest.raises(ValueError, match=r"cannot remove index"): data.reset_coords("y") def test_assign_coords(self): @@ -1434,7 +1452,7 @@ def test_assign_coords(self): expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) assert_identical(actual, expected) - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): self.mda.assign_coords(level_1=range(4)) # GH: 2112 @@ -1457,7 +1475,7 @@ def test_coords_alignment(self): def test_set_coords_update_index(self): actual = DataArray([1, 2, 3], [("x", [1, 2, 3])]) actual.coords["x"] = ["a", "b", "c"] - assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) def test_coords_replacement_alignment(self): # regression test for GH725 @@ -1477,7 +1495,7 @@ def test_coords_delitem_delete_indexes(self): # regression test for GH3746 arr = DataArray(np.ones((2,)), dims="x", coords={"x": [0, 1]}) del arr.coords["x"] - assert "x" not in arr.indexes + assert "x" not in arr.xindexes def test_broadcast_like(self): arr1 = DataArray( @@ -1519,7 +1537,7 @@ def test_reindex_like_no_index(self): assert_identical(foo, foo.reindex_like(foo)) bar = foo[:4] - with raises_regex(ValueError, "different size for unlabeled"): + with pytest.raises(ValueError, match=r"different size for unlabeled"): foo.reindex_like(bar) def test_reindex_regressions(self): @@ -1615,28 +1633,20 @@ def test_init_value(self): actual = DataArray(coords=[("x", np.arange(10)), ("y", ["a", "b"])]) assert_identical(expected, actual) - with raises_regex(ValueError, "different number of dim"): + with pytest.raises(ValueError, match=r"different number of dim"): DataArray(np.array(1), coords={"x": np.arange(10)}, dims=["x"]) - with raises_regex(ValueError, "does not match the 0 dim"): + with pytest.raises(ValueError, match=r"does not match the 0 dim"): DataArray(np.array(1), coords=[("x", np.arange(10))]) def test_swap_dims(self): - array = DataArray(np.random.randn(3), {"y": ("x", list("abc"))}, "x") - expected = DataArray(array.values, {"y": list("abc")}, dims="y") - actual = array.swap_dims({"x": "y"}) - assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): - pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] - ) - array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) # as kwargs @@ -1644,9 +1654,10 @@ def test_swap_dims(self): expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") actual = array.swap_dims(x="y") assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) # multiindex case @@ -1655,9 +1666,10 @@ def test_swap_dims(self): expected = DataArray(array.values, {"y": idx}, "y") actual = array.swap_dims({"x": "y"}) assert_identical(expected, actual) - for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()): + for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): pd.testing.assert_index_equal( - expected.indexes[dim_name], actual.indexes[dim_name] + expected.xindexes[dim_name].to_pandas_index(), + actual.xindexes[dim_name].to_pandas_index(), ) def test_expand_dims_error(self): @@ -1668,20 +1680,20 @@ def test_expand_dims_error(self): attrs={"key": "entry"}, ) - with raises_regex(TypeError, "dim should be hashable or"): + with pytest.raises(TypeError, match=r"dim should be hashable or"): array.expand_dims(0) - with raises_regex(ValueError, "lengths of dim and axis"): + with pytest.raises(ValueError, match=r"lengths of dim and axis"): # dims and axis argument should be the same length array.expand_dims(dim=["a", "b"], axis=[1, 2, 3]) - with raises_regex(ValueError, "Dimension x already"): + with pytest.raises(ValueError, match=r"Dimension x already"): # Should not pass the already existing dimension. array.expand_dims(dim=["x"]) # raise if duplicate - with raises_regex(ValueError, "duplicate values"): + with pytest.raises(ValueError, match=r"duplicate values"): array.expand_dims(dim=["y", "y"]) - with raises_regex(ValueError, "duplicate values"): + with pytest.raises(ValueError, match=r"duplicate values"): array.expand_dims(dim=["y", "z"], axis=[1, 1]) - with raises_regex(ValueError, "duplicate values"): + with pytest.raises(ValueError, match=r"duplicate values"): array.expand_dims(dim=["y", "z"], axis=[2, -2]) # out of bounds error, axis must be in [-4, 3] @@ -1847,7 +1859,7 @@ def test_set_index(self): coords={"x": ("x", [0, 1]), "level": ("y", [1, 2])}, dims=("x", "y"), ) - with raises_regex(ValueError, "dimension mismatch"): + with pytest.raises(ValueError, match=r"dimension mismatch"): array2d.set_index(x="level") # Issue 3176: Ensure clear error message on key error. @@ -1907,7 +1919,7 @@ def test_reorder_levels(self): array.reorder_levels(x=["level_1", "level_2"]) array["x"] = [0, 1] - with raises_regex(ValueError, "has no MultiIndex"): + with pytest.raises(ValueError, match=r"has no MultiIndex"): array.reorder_levels(x=["level_1", "level_2"]) def test_dataset_getitem(self): @@ -2208,12 +2220,10 @@ def test_unstack_pandas_consistency(self): actual = DataArray(s, dims="z").unstack("z") assert_identical(expected, actual) - def test_stack_nonunique_consistency(self): - orig = DataArray( - [[0, 1], [2, 3]], dims=["x", "y"], coords={"x": [0, 1], "y": [0, 0]} - ) - actual = orig.stack(z=["x", "y"]) - expected = DataArray(orig.to_pandas().stack(), dims="z") + def test_stack_nonunique_consistency(self, da): + da = da.isel(time=0, drop=True) # 2D + actual = da.stack(z=["a", "x"]) + expected = DataArray(da.to_pandas().stack(), dims="z") assert_identical(expected, actual) def test_to_unstacked_dataset_raises_value_error(self): @@ -2309,14 +2319,14 @@ def test_drop_coordinates(self): actual = expected.drop_vars("not found", errors="ignore") assert_identical(actual, expected) - with raises_regex(ValueError, "cannot be found"): + with pytest.raises(ValueError, match=r"cannot be found"): arr.drop_vars("w") actual = expected.drop_vars("w", errors="ignore") assert_identical(actual, expected) renamed = arr.rename("foo") - with raises_regex(ValueError, "cannot be found"): + with pytest.raises(ValueError, match=r"cannot be found"): renamed.drop_vars("foo") actual = renamed.drop_vars("foo", errors="ignore") @@ -2328,7 +2338,7 @@ def test_drop_index_labels(self): expected = arr[:, 2:] assert_identical(actual, expected) - with raises_regex((KeyError, ValueError), "not .* in axis"): + with pytest.raises((KeyError, ValueError), match=r"not .* in axis"): actual = arr.drop_sel(y=[0, 1, 3]) actual = arr.drop_sel(y=[0, 1, 3], errors="ignore") @@ -2598,10 +2608,10 @@ def test_fillna(self): actual = a.fillna(b[:0]) assert_identical(a, actual) - with raises_regex(TypeError, "fillna on a DataArray"): + with pytest.raises(TypeError, match=r"fillna on a DataArray"): a.fillna({0: 0}) - with raises_regex(ValueError, "broadcast"): + with pytest.raises(ValueError, match=r"broadcast"): a.fillna([1, 2]) fill_value = DataArray([0, 1], dims="y") @@ -2816,11 +2826,11 @@ def test_groupby_math(self): actual_agg = actual.groupby("abc").mean(...) assert_allclose(expected_agg, actual_agg) - with raises_regex(TypeError, "only support binary ops"): + with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 - with raises_regex(TypeError, "only support binary ops"): + with pytest.raises(TypeError, match=r"only support binary ops"): grouped + grouped - with raises_regex(TypeError, "in-place operations"): + with pytest.raises(TypeError, match=r"in-place operations"): array += grouped def test_groupby_math_not_aligned(self): @@ -2990,11 +3000,15 @@ def test_resample(self): actual = array.resample(time="24H").reduce(np.mean) assert_identical(expected, actual) + # Our use of `loffset` may change if we align our API with pandas' changes. + # ref https://github.com/pydata/xarray/pull/4537 actual = array.resample(time="24H", loffset="-12H").mean() - expected = DataArray(array.to_series().resample("24H", loffset="-12H").mean()) - assert_identical(expected, actual) + expected_ = array.to_series().resample("24H").mean() + expected_.index += to_offset("-12H") + expected = DataArray.from_series(expected_) + assert_identical(actual, expected) - with raises_regex(ValueError, "index must be monotonic"): + with pytest.raises(ValueError, match=r"index must be monotonic"): array[[2, 0, 1]].resample(time="1D") def test_da_resample_func_args(self): @@ -3043,7 +3057,7 @@ def test_resample_first(self): def test_resample_bad_resample_dim(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) array = DataArray(np.arange(10), [("__resample_dim__", times)]) - with raises_regex(ValueError, "Proxy resampling dimension"): + with pytest.raises(ValueError, match=r"Proxy resampling dimension"): array.resample(**{"__resample_dim__": "1D"}).first() @requires_scipy @@ -3087,6 +3101,11 @@ def test_resample_keep_attrs(self): expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) assert_identical(result, expected) + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." + ): + array.resample(time="1D", keep_attrs=True) + def test_resample_skipna(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) array = DataArray(np.ones(10), [("time", times)]) @@ -3380,7 +3399,7 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with raises_regex(ValueError, "Indexes along dimension 'x' don't have"): + with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): align(left.isel(x=0).expand_dims("x"), right, join="override") @pytest.mark.parametrize( @@ -3399,7 +3418,7 @@ def test_align_override(self): ], ) def test_align_override_error(self, darrays): - with raises_regex(ValueError, "Indexes along dimension 'x' don't have"): + with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): xr.align(*darrays, join="override") def test_align_exclude(self): @@ -3455,10 +3474,10 @@ def test_align_mixed_indexes(self): assert_identical(result1, array_with_coord) def test_align_without_indexes_errors(self): - with raises_regex(ValueError, "cannot be aligned"): + with pytest.raises(ValueError, match=r"cannot be aligned"): align(DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], dims=["x"])) - with raises_regex(ValueError, "cannot be aligned"): + with pytest.raises(ValueError, match=r"cannot be aligned"): align( DataArray([1, 2, 3], dims=["x"]), DataArray([1, 2], coords=[("x", [0, 1])]), @@ -3598,7 +3617,7 @@ def test_to_pandas(self): roundtripped = DataArray(da.to_pandas()).drop_vars(dims) assert_identical(da, roundtripped) - with raises_regex(ValueError, "cannot convert"): + with pytest.raises(ValueError, match=r"cannot convert"): DataArray(np.random.randn(1, 2, 3, 4, 5)).to_pandas() def test_to_dataframe(self): @@ -3632,9 +3651,36 @@ def test_to_dataframe(self): arr.sel(A="c", B=2).to_dataframe() arr.name = None # unnamed - with raises_regex(ValueError, "unnamed"): + with pytest.raises(ValueError, match=r"unnamed"): arr.to_dataframe() + def test_to_dataframe_multiindex(self): + # regression test for #3008 + arr_np = np.random.randn(4, 3) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") + + actual = arr.to_dataframe() + assert_array_equal(actual["foo"].values, arr_np.flatten()) + assert_array_equal(actual.index.names, list("ABC")) + assert_array_equal(actual.index.levels[0], [1, 2]) + assert_array_equal(actual.index.levels[1], ["a", "b"]) + assert_array_equal(actual.index.levels[2], [5, 6, 7]) + + def test_to_dataframe_0length(self): + # regression test for #3008 + arr_np = np.random.randn(4, 0) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo") + + actual = arr.to_dataframe() + assert len(actual) == 0 + assert_array_equal(actual.index.names, list("ABC")) + def test_to_pandas_name_matches_coordinate(self): # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x") @@ -3758,14 +3804,17 @@ def test_to_and_from_dict(self): "data": array.values, "coords": {"x": {"data": ["a", "b"]}}, } - with raises_regex( - ValueError, "cannot convert dict when coords are missing the key 'dims'" + with pytest.raises( + ValueError, + match=r"cannot convert dict when coords are missing the key 'dims'", ): DataArray.from_dict(d) # this one is missing some necessary information d = {"dims": ("t")} - with raises_regex(ValueError, "cannot convert dict without the key 'data'"): + with pytest.raises( + ValueError, match=r"cannot convert dict without the key 'data'" + ): DataArray.from_dict(d) # check the data=False option @@ -3950,7 +3999,7 @@ def test_to_and_from_cdms2_ugrid(self): def test_to_dataset_whole(self): unnamed = DataArray([1, 2], dims="x") - with raises_regex(ValueError, "unable to convert unnamed"): + with pytest.raises(ValueError, match=r"unable to convert unnamed"): unnamed.to_dataset() actual = unnamed.to_dataset(name="foo") @@ -4145,11 +4194,11 @@ def test_real_and_imag(self): def test_setattr_raises(self): array = DataArray(0, coords={"scalar": 1}, attrs={"foo": "bar"}) - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): array.scalar = 2 - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): array.foo = 2 - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): array.other = 2 def test_full_like(self): @@ -4173,6 +4222,9 @@ def test_full_like(self): assert expect.dtype == bool assert_identical(expect, actual) + with pytest.raises(ValueError, match="'dtype' cannot be dict-like"): + full_like(da, fill_value=True, dtype={"x": bool}) + def test_dot(self): x = np.linspace(-3, 3, 6) y = np.linspace(-3, 3, 5) @@ -4233,7 +4285,7 @@ def test_dot_align_coords(self): dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) with xr.set_options(arithmetic_join="exact"): - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): da.dot(dm) da_aligned, dm_aligned = xr.align(da, dm, join="inner") @@ -4284,18 +4336,18 @@ def test_matmul_align_coords(self): assert_identical(result, expected) with xr.set_options(arithmetic_join="exact"): - with raises_regex(ValueError, "indexes along dimension"): + with pytest.raises(ValueError, match=r"indexes along dimension"): da_a @ da_b def test_binary_op_propagate_indexes(self): # regression test for GH2227 self.dv["x"] = np.arange(self.dv.sizes["x"]) - expected = self.dv.indexes["x"] + expected = self.dv.xindexes["x"] - actual = (self.dv * 10).indexes["x"] + actual = (self.dv * 10).xindexes["x"] assert expected is actual - actual = (self.dv > 10).indexes["x"] + actual = (self.dv > 10).xindexes["x"] assert expected is actual def test_binary_op_join_setting(self): @@ -4400,6 +4452,7 @@ def test_rank(self): @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("use_datetime", [True, False]) + @pytest.mark.filterwarnings("ignore:overflow encountered in multiply") def test_polyfit(self, use_dask, use_datetime): if use_dask and not has_dask: pytest.skip("requires dask") @@ -4472,6 +4525,26 @@ def test_pad_constant(self): assert actual.shape == (7, 4, 5) assert_identical(actual, expected) + ar = xr.DataArray([9], dims="x") + + actual = ar.pad(x=1) + expected = xr.DataArray([np.NaN, 9, np.NaN], dims="x") + assert_identical(actual, expected) + + actual = ar.pad(x=1, constant_values=1.23456) + expected = xr.DataArray([1, 9, 1], dims="x") + assert_identical(actual, expected) + + if LooseVersion(np.__version__) >= "1.20": + with pytest.raises(ValueError, match="cannot convert float NaN to integer"): + ar.pad(x=1, constant_values=np.NaN) + else: + actual = ar.pad(x=1, constant_values=np.NaN) + expected = xr.DataArray( + [-9223372036854775808, 9, -9223372036854775808], dims="x" + ) + assert_identical(actual, expected) + def test_pad_coords(self): ar = DataArray( np.arange(3 * 4 * 5).reshape(3, 4, 5), @@ -4565,6 +4638,128 @@ def test_pad_reflect(self, mode, reflect_type): assert actual.shape == (7, 4, 9) assert_identical(actual, expected) + @pytest.mark.parametrize("parser", ["pandas", "python"]) + @pytest.mark.parametrize( + "engine", ["python", None, pytest.param("numexpr", marks=[requires_numexpr])] + ) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=[requires_dask])] + ) + def test_query(self, backend, engine, parser): + """Test querying a dataset.""" + + # setup test data + np.random.seed(42) + a = np.arange(0, 10, 1) + b = np.random.randint(0, 100, size=10) + c = np.linspace(0, 1, 20) + d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( + object + ) + if backend == "numpy": + aa = DataArray(data=a, dims=["x"], name="a") + bb = DataArray(data=b, dims=["x"], name="b") + cc = DataArray(data=c, dims=["y"], name="c") + dd = DataArray(data=d, dims=["z"], name="d") + + elif backend == "dask": + import dask.array as da + + aa = DataArray(data=da.from_array(a, chunks=3), dims=["x"], name="a") + bb = DataArray(data=da.from_array(b, chunks=3), dims=["x"], name="b") + cc = DataArray(data=da.from_array(c, chunks=7), dims=["y"], name="c") + dd = DataArray(data=da.from_array(d, chunks=12), dims=["z"], name="d") + + # query single dim, single variable + actual = aa.query(x="a > 5", engine=engine, parser=parser) + expect = aa.isel(x=(a > 5)) + assert_identical(expect, actual) + + # query single dim, single variable, via dict + actual = aa.query(dict(x="a > 5"), engine=engine, parser=parser) + expect = aa.isel(dict(x=(a > 5))) + assert_identical(expect, actual) + + # query single dim, single variable + actual = bb.query(x="b > 50", engine=engine, parser=parser) + expect = bb.isel(x=(b > 50)) + assert_identical(expect, actual) + + # query single dim, single variable + actual = cc.query(y="c < .5", engine=engine, parser=parser) + expect = cc.isel(y=(c < 0.5)) + assert_identical(expect, actual) + + # query single dim, single string variable + if parser == "pandas": + # N.B., this query currently only works with the pandas parser + # xref https://github.com/pandas-dev/pandas/issues/40436 + actual = dd.query(z='d == "bar"', engine=engine, parser=parser) + expect = dd.isel(z=(d == "bar")) + assert_identical(expect, actual) + + # test error handling + with pytest.raises(ValueError): + aa.query("a > 5") # must be dict or kwargs + with pytest.raises(ValueError): + aa.query(x=(a > 5)) # must be query string + with pytest.raises(UndefinedVariableError): + aa.query(x="spam > 50") # name not present + + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit(self, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + + def exp_decay(t, n0, tau=1): + return n0 * np.exp(-t / tau) + + t = np.arange(0, 5, 0.5) + da = DataArray( + np.stack([exp_decay(t, 3, 3), exp_decay(t, 5, 4), np.nan * t], axis=-1), + dims=("t", "x"), + coords={"t": t, "x": [0, 1, 2]}, + ) + da[0, 0] = np.nan + + expected = DataArray( + [[3, 3], [5, 4], [np.nan, np.nan]], + dims=("x", "param"), + coords={"x": [0, 1, 2], "param": ["n0", "tau"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + fit = da.curvefit( + coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": [2, 6]} + ) + assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) + + da = da.compute() + fit = da.curvefit(coords="t", func=np.power, reduce_dims="x", param_names=["a"]) + assert "a" in fit.param + assert "x" not in fit.dims + + def test_curvefit_helpers(self): + def exp_decay(t, n0, tau=1): + return n0 * np.exp(-t / tau) + + params, func_args = xr.core.dataset._get_func_args(exp_decay, []) + assert params == ["n0", "tau"] + param_defaults, bounds_defaults = xr.core.dataset._initialize_curvefit_params( + params, {"n0": 4}, {"tau": [5, np.inf]}, func_args + ) + assert param_defaults == {"n0": 4, "tau": 6} + assert bounds_defaults == {"n0": (-np.inf, np.inf), "tau": (5, np.inf)} + + param_names = ["a"] + params, func_args = xr.core.dataset._get_func_args(np.power, param_names) + assert params == param_names + with pytest.raises(ValueError): + xr.core.dataset._get_func_args(np.power, []) + class TestReduce: @pytest.fixture(autouse=True) @@ -6256,34 +6451,31 @@ def test_idxminmax_dask(self, op, ndim): @pytest.fixture(params=[1]) -def da(request): +def da(request, backend): if request.param == 1: times = pd.date_range("2000-01-01", freq="1D", periods=21) - values = np.random.random((3, 21, 4)) - da = DataArray(values, dims=("a", "time", "x")) - da["time"] = times - return da + da = DataArray( + np.random.random((3, 21, 4)), + dims=("a", "time", "x"), + coords=dict(time=times), + ) if request.param == 2: - return DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") + da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") if request.param == "repeating_ints": - return DataArray( + da = DataArray( np.tile(np.arange(12), 5).reshape(5, 4, 3), coords={"x": list("abc"), "y": list("defg")}, dims=list("zyx"), ) - -@pytest.fixture -def da_dask(seed=123): - pytest.importorskip("dask.array") - rs = np.random.RandomState(seed) - times = pd.date_range("2000-01-01", freq="1D", periods=21) - values = rs.normal(size=(1, 21, 1)) - da = DataArray(values, dims=("a", "time", "x")).chunk({"time": 7}) - da["time"] = times - return da + if backend == "dask": + return da.chunk() + elif backend == "numpy": + return da + else: + raise ValueError @pytest.mark.parametrize("da", ("repeating_ints",), indirect=True) @@ -6306,35 +6498,6 @@ def test_isin(da): assert_equal(result, expected) -def test_coarsen_keep_attrs(): - _attrs = {"units": "test", "long_name": "testing"} - - da = xr.DataArray( - np.linspace(0, 364, num=364), - dims="time", - coords={"time": pd.date_range("15/12/1999", periods=364)}, - attrs=_attrs, - ) - - da2 = da.copy(deep=True) - - # Test dropped attrs - dat = da.coarsen(time=3, boundary="trim").mean() - assert dat.attrs == {} - - # Test kept attrs using dataset keyword - dat = da.coarsen(time=3, boundary="trim", keep_attrs=True).mean() - assert dat.attrs == _attrs - - # Test kept attrs using global option - with xr.set_options(keep_attrs=True): - dat = da.coarsen(time=3, boundary="trim").mean() - assert dat.attrs == _attrs - - # Test kept attrs in original object - xr.testing.assert_identical(da, da2) - - @pytest.mark.parametrize("da", (1, 2), indirect=True) def test_rolling_iter(da): rolling_obj = da.rolling(time=7) @@ -6369,6 +6532,15 @@ def test_rolling_repr(da): assert repr(rolling_obj) == "DataArrayRolling [time->7(center),x->3(center)]" +@requires_dask +def test_repeated_rolling_rechunks(): + + # regression test for GH3277, GH2514 + dat = DataArray(np.random.rand(7653, 300), dims=("day", "item")) + dat_chunk = dat.chunk({"item": 20}) + dat_chunk.rolling(day=10).mean().rolling(day=250).std() + + def test_rolling_doc(da): rolling_obj = da.rolling(time=7) @@ -6392,6 +6564,7 @@ def test_rolling_properties(da): @pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) def test_rolling_wrapped_bottleneck(da, name, center, min_periods): bn = pytest.importorskip("bottleneck", minversion="1.1") @@ -6419,17 +6592,16 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @pytest.mark.parametrize("window", (7, 8)) -def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): +@pytest.mark.parametrize("backend", ["dask"], indirect=True) +def test_rolling_wrapped_dask(da, name, center, min_periods, window): # dask version - rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, center=center) + rolling_obj = da.rolling(time=window, min_periods=min_periods, center=center) actual = getattr(rolling_obj, name)().load() if name != "count": with pytest.warns(DeprecationWarning, match="Reductions are applied"): getattr(rolling_obj, name)(dim="time") # numpy version - rolling_obj = da_dask.load().rolling( - time=window, min_periods=min_periods, center=center - ) + rolling_obj = da.load().rolling(time=window, min_periods=min_periods, center=center) expected = getattr(rolling_obj, name)() # using all-close because rolling over ghost cells introduces some @@ -6437,7 +6609,7 @@ def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): assert_allclose(actual, expected) # with zero chunked array GH:2113 - rolling_obj = da_dask.chunk().rolling( + rolling_obj = da.chunk().rolling( time=window, min_periods=min_periods, center=center ) actual = getattr(rolling_obj, name)().load() @@ -6603,6 +6775,16 @@ def test_ndrolling_reduce(da, center, min_periods, name): assert_allclose(actual, expected) assert actual.dims == expected.dims + if name in ["mean"]: + # test our reimplementation of nanmean using np.nanmean + expected = getattr(rolling_obj.construct({"time": "tw", "x": "xw"}), name)( + ["tw", "xw"] + ) + count = rolling_obj.count() + if min_periods is None: + min_periods = 1 + assert_allclose(actual, expected.where(count >= min_periods)) + @pytest.mark.parametrize("center", (True, False, (True, False))) @pytest.mark.parametrize("fill_value", (np.nan, 0.0)) @@ -6678,38 +6860,6 @@ def test_rolling_keep_attrs(funcname, argument): assert result.name == "name" -def test_rolling_keep_attrs_deprecated(): - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - da = DataArray( - data, - dims=("coord"), - coords={"coord": coords}, - attrs=attrs_da, - ) - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = da.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") - - assert result.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = da.rolling(dim={"coord": 5}, keep_attrs=True).construct( - "window_dim", keep_attrs=False - ) - - assert result.attrs == {} - - def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 @@ -6992,10 +7142,33 @@ def test_fallback_to_iris_AuxCoord(self, coord_values): @pytest.mark.parametrize( "window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]] ) -def test_rolling_exp(da, dim, window_type, window): - da = da.isel(a=0) +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) +@pytest.mark.parametrize("func", ["mean", "sum"]) +def test_rolling_exp_runs(da, dim, window_type, window, func): + import numbagg + + if ( + LooseVersion(getattr(numbagg, "__version__", "0.1.0")) < "0.2.1" + and func == "sum" + ): + pytest.skip("rolling_exp.sum requires numbagg 0.2.1") + da = da.where(da > 0.2) + rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window}) + result = getattr(rolling_exp, func)() + assert isinstance(result, DataArray) + + +@requires_numbagg +@pytest.mark.parametrize("dim", ["time", "x"]) +@pytest.mark.parametrize( + "window_type, window", [["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]] +) +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) +def test_rolling_exp_mean_pandas(da, dim, window_type, window): + da = da.isel(a=0).where(lambda x: x > 0.2) + result = da.rolling_exp(window_type=window_type, **{dim: window}).mean() assert isinstance(result, DataArray) @@ -7011,32 +7184,50 @@ def test_rolling_exp(da, dim, window_type, window): @requires_numbagg -def test_rolling_exp_keep_attrs(da): +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) +@pytest.mark.parametrize("func", ["mean", "sum"]) +def test_rolling_exp_keep_attrs(da, func): + import numbagg + + if ( + LooseVersion(getattr(numbagg, "__version__", "0.1.0")) < "0.2.1" + and func == "sum" + ): + pytest.skip("rolling_exp.sum requires numbagg 0.2.1") + attrs = {"attrs": "da"} da.attrs = attrs + # Equivalent of `da.rolling_exp(time=10).mean` + rolling_exp_func = getattr(da.rolling_exp(time=10), func) + # attrs are kept per default - result = da.rolling_exp(time=10).mean() + result = rolling_exp_func() assert result.attrs == attrs # discard attrs - result = da.rolling_exp(time=10).mean(keep_attrs=False) + result = rolling_exp_func(keep_attrs=False) assert result.attrs == {} # test discard attrs using global option with set_options(keep_attrs=False): - result = da.rolling_exp(time=10).mean() + result = rolling_exp_func() assert result.attrs == {} # keyword takes precedence over global option with set_options(keep_attrs=False): - result = da.rolling_exp(time=10).mean(keep_attrs=True) + result = rolling_exp_func(keep_attrs=True) assert result.attrs == attrs with set_options(keep_attrs=True): - result = da.rolling_exp(time=10).mean(keep_attrs=False) + result = rolling_exp_func(keep_attrs=False) assert result.attrs == {} + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``rolling_exp`` has no effect." + ): + da.rolling_exp(time=10, keep_attrs=True) + def test_no_dict(): d = DataArray() @@ -7099,3 +7290,140 @@ def test_deepcopy_obj_array(): x0 = DataArray(np.array([object()])) x1 = deepcopy(x0) assert x0.values[0] is not x1.values[0] + + +def test_clip(da): + with raise_if_dask_computes(): + result = da.clip(min=0.5) + assert result.min(...) >= 0.5 + + result = da.clip(max=0.5) + assert result.max(...) <= 0.5 + + result = da.clip(min=0.25, max=0.75) + assert result.min(...) >= 0.25 + assert result.max(...) <= 0.75 + + with raise_if_dask_computes(): + result = da.clip(min=da.mean("x"), max=da.mean("a")) + assert result.dims == da.dims + assert_array_equal( + result.data, + np.clip(da.data, da.mean("x").data[:, :, np.newaxis], da.mean("a").data), + ) + + with_nans = da.isel(time=[0, 1]).reindex_like(da) + with raise_if_dask_computes(): + result = da.clip(min=da.mean("x"), max=da.mean("a")) + result = da.clip(with_nans) + # The values should be the same where there were NaNs. + assert_array_equal(result.isel(time=[0, 1]), with_nans.isel(time=[0, 1])) + + # Unclear whether we want this work, OK to adjust the test when we have decided. + with pytest.raises(ValueError, match="arguments without labels along dimension"): + result = da.clip(min=da.mean("x"), max=da.mean("a").isel(x=[0, 1])) + + +@pytest.mark.parametrize("keep", ["first", "last", False]) +def test_drop_duplicates(keep): + ds = xr.DataArray( + [0, 5, 6, 7], dims="time", coords={"time": [0, 0, 1, 2]}, name="test" + ) + + if keep == "first": + data = [0, 6, 7] + time = [0, 1, 2] + elif keep == "last": + data = [5, 6, 7] + time = [0, 1, 2] + else: + data = [6, 7] + time = [1, 2] + + expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test") + result = ds.drop_duplicates("time", keep=keep) + assert_equal(expected, result) + + +class TestNumpyCoercion: + # TODO once flexible indexes refactor complete also test coercion of dimension coords + def test_from_numpy(self): + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(da.as_numpy(), da) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_dask + def test_from_dask(self): + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + da_chunked = da.chunk(1) + + assert_identical(da_chunked.as_numpy(), da.compute()) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_pint_0_15 + def test_from_pint(self): + from pint import Quantity + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + Quantity(arr, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + np.testing.assert_equal(da["lat"].to_numpy(), arr + 3) + + @requires_sparse + def test_from_sparse(self): + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + da = xr.DataArray( + sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.DataArray( + arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self): + import cupy as cp + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + da = xr.DataArray( + Quantity(d, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = da.as_numpy() + result.name = None # remove dask-assigned name + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) + np.testing.assert_equal(da.to_numpy(), arr) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index db47faa8d2b..9b8b7c748f1 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -8,7 +8,9 @@ import numpy as np import pandas as pd import pytest +from pandas.core.computation.ops import UndefinedVariableError from pandas.core.indexes.datetimes import DatetimeIndex +from pandas.tseries.frequencies import to_offset import xarray as xr from xarray import ( @@ -26,7 +28,7 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import dtypes, indexing, utils from xarray.core.common import duck_array_ops, full_like -from xarray.core.npcompat import IS_NEP18_ACTIVE +from xarray.core.indexes import Index from xarray.core.pycompat import integer_types from xarray.core.utils import is_scalar @@ -37,13 +39,16 @@ assert_array_equal, assert_equal, assert_identical, + create_test_data, has_cftime, has_dask, - raises_regex, requires_bottleneck, requires_cftime, + requires_cupy, requires_dask, requires_numbagg, + requires_numexpr, + requires_pint_0_15, requires_scipy, requires_sparse, source_ndarray, @@ -60,31 +65,6 @@ ] -def create_test_data(seed=None): - rs = np.random.RandomState(seed) - _vars = { - "var1": ["dim1", "dim2"], - "var2": ["dim1", "dim2"], - "var3": ["dim3", "dim1"], - } - _dims = {"dim1": 8, "dim2": 9, "dim3": 10} - - obj = Dataset() - obj["time"] = ("time", pd.date_range("2000-01-01", periods=20)) - obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"])) - obj["dim3"] = ("dim3", list("abcdefghij")) - for v, dims in sorted(_vars.items()): - data = rs.normal(size=tuple(_dims[d] for d in dims)) - obj[v] = (dims, data, {"foo": "variable"}) - obj.coords["numbers"] = ( - "dim3", - np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"), - ) - obj.encoding = {"foo": "bar"} - assert all(obj.data.flags.writeable for obj in obj.variables.values()) - return obj - - def create_append_test_data(seed=None): rs = np.random.RandomState(seed) @@ -187,7 +167,7 @@ def get_variables(self): def lazy_inaccessible(k, v): if k in self._indexvars: return v - data = indexing.LazilyOuterIndexedArray(InaccessibleArray(v.values)) + data = indexing.LazilyIndexedArray(InaccessibleArray(v.values)) return Variable(v.dims, data, v.attrs) return {k: lazy_inaccessible(k, v) for k, v in self._variables.items()} @@ -201,11 +181,11 @@ def test_repr(self): expected = dedent( """\ - Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20) + Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20 * dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 * dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20 numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3 Dimensions without coordinates: dim1 Data variables: @@ -317,7 +297,6 @@ def test_unicode_data(self): actual = str(data) assert expected == actual - @pytest.mark.skipif(not IS_NEP18_ACTIVE, reason="requires __array_function__") def test_repr_nep18(self): class Array: def __init__(self): @@ -354,14 +333,14 @@ def test_info(self): """\ xarray.Dataset { dimensions: - \tdim1 = 8 ; \tdim2 = 9 ; - \tdim3 = 10 ; \ttime = 20 ; + \tdim1 = 8 ; + \tdim3 = 10 ; variables: - \tdatetime64[ns] time(time) ; \tfloat64 dim2(dim2) ; + \tdatetime64[ns] time(time) ; \tfloat64 var1(dim1, dim2) ; \t\tvar1:foo = variable ; \tfloat64 var2(dim1, dim2) ; @@ -384,13 +363,13 @@ def test_constructor(self): x2 = ("x", np.arange(1000)) z = (["x", "y"], np.arange(1000).reshape(100, 10)) - with raises_regex(ValueError, "conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting sizes"): Dataset({"a": x1, "b": x2}) - with raises_regex(ValueError, "disallows such variables"): + with pytest.raises(ValueError, match=r"disallows such variables"): Dataset({"a": x1, "x": z}) - with raises_regex(TypeError, "tuple of form"): + with pytest.raises(TypeError, match=r"tuple of form"): Dataset({"x": (1, 2, 3, 4, 5, 6, 7)}) - with raises_regex(ValueError, "already exists as a scalar"): + with pytest.raises(ValueError, match=r"already exists as a scalar"): Dataset({"x": 0, "y": ("x", [1, 2, 3])}) # verify handling of DataArrays @@ -442,10 +421,6 @@ class Arbitrary: actual = Dataset({"x": arg}) assert_identical(expected, actual) - def test_constructor_deprecated(self): - with raises_regex(ValueError, "DataArray dimensions"): - DataArray([1, 2, 3], coords={"x": [0, 1, 2]}) - def test_constructor_auto_align(self): a = DataArray([1, 2], [("x", [0, 1])]) b = DataArray([3, 4], [("x", [1, 2])]) @@ -473,7 +448,7 @@ def test_constructor_auto_align(self): assert_identical(expected3, actual) e = ("x", [0, 0]) - with raises_regex(ValueError, "conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting sizes"): Dataset({"a": a, "b": b, "e": e}) def test_constructor_pandas_sequence(self): @@ -540,7 +515,7 @@ def test_constructor_compat(self): assert_identical(expected, actual) def test_constructor_with_coords(self): - with raises_regex(ValueError, "found in both data_vars and"): + with pytest.raises(ValueError, match=r"found in both data_vars and"): Dataset({"a": ("x", [1])}, {"a": ("x", [1])}) ds = Dataset({}, {"a": ("x", [1])}) @@ -550,21 +525,20 @@ def test_constructor_with_coords(self): mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") ) - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): Dataset({}, {"x": mindex, "y": mindex}) Dataset({}, {"x": mindex, "level_1": range(4)}) def test_properties(self): ds = create_test_data() assert ds.dims == {"dim1": 8, "dim2": 9, "dim3": 10, "time": 20} - assert list(ds.dims) == sorted(ds.dims) assert ds.sizes == ds.dims # These exact types aren't public API, but this makes sure we don't # change them inadvertently: assert isinstance(ds.dims, utils.Frozen) - assert isinstance(ds.dims.mapping, utils.SortedKeysDict) - assert type(ds.dims.mapping.mapping) is dict + assert isinstance(ds.dims.mapping, dict) + assert type(ds.dims.mapping) is dict assert list(ds) == list(ds.data_vars) assert list(ds.keys()) == list(ds.data_vars) @@ -580,11 +554,17 @@ def test_properties(self): assert "numbers" not in ds.data_vars assert len(ds.data_vars) == 3 + assert set(ds.xindexes) == {"dim2", "dim3", "time"} + assert len(ds.xindexes) == 3 + assert "dim2" in repr(ds.xindexes) + assert all([isinstance(idx, Index) for idx in ds.xindexes.values()]) + assert set(ds.indexes) == {"dim2", "dim3", "time"} assert len(ds.indexes) == 3 assert "dim2" in repr(ds.indexes) + assert all([isinstance(idx, pd.Index) for idx in ds.indexes.values()]) - assert list(ds.coords) == ["time", "dim2", "dim3", "numbers"] + assert list(ds.coords) == ["dim2", "dim3", "time", "numbers"] assert "dim2" in ds.coords assert "numbers" in ds.coords assert "var1" not in ds.coords @@ -595,7 +575,7 @@ def test_properties(self): def test_asarray(self): ds = Dataset({"x": 0}) - with raises_regex(TypeError, "cannot directly convert"): + with pytest.raises(TypeError, match=r"cannot directly convert"): np.asarray(ds) def test_get_index(self): @@ -723,7 +703,7 @@ def test_coords_modify(self): assert_array_equal(actual["z"], ["a", "b"]) actual = data.copy(deep=True) - with raises_regex(ValueError, "conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting sizes"): actual.coords["x"] = ("x", [-1]) assert_identical(actual, data) # should not be modified @@ -745,12 +725,12 @@ def test_coords_modify(self): # regression test for GH3746 del actual.coords["x"] - assert "x" not in actual.indexes + assert "x" not in actual.xindexes def test_update_index(self): actual = Dataset(coords={"x": [1, 2, 3]}) actual["x"] = ["a", "b", "c"] - assert actual.indexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) def test_coords_setitem_with_new_dimension(self): actual = Dataset() @@ -760,7 +740,7 @@ def test_coords_setitem_with_new_dimension(self): def test_coords_setitem_multiindex(self): data = create_test_multiindex() - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): data.coords["level_1"] = range(4) def test_coords_set(self): @@ -793,7 +773,7 @@ def test_coords_set(self): actual = all_coords.reset_coords("zzz") assert_identical(two_coords, actual) - with raises_regex(ValueError, "cannot remove index"): + with pytest.raises(ValueError, match=r"cannot remove index"): one_coord.reset_coords("x") actual = all_coords.reset_coords("zzz", drop=True) @@ -969,7 +949,7 @@ def get_dask_names(ds): for k, v in new_dask_names.items(): assert v == orig_dask_names[k] - with raises_regex(ValueError, "some chunks"): + with pytest.raises(ValueError, match=r"some chunks"): data.chunk({"foo": 10}) @requires_dask @@ -1022,18 +1002,18 @@ def test_isel(self): with pytest.raises(ValueError): data.isel(not_a_dim=slice(0, 2)) - with raises_regex( + with pytest.raises( ValueError, - r"Dimensions {'not_a_dim'} do not exist. Expected " + match=r"Dimensions {'not_a_dim'} do not exist. Expected " r"one or more of " - r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*", + r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*", ): data.isel(not_a_dim=slice(0, 2)) with pytest.warns( UserWarning, match=r"Dimensions {'not_a_dim'} do not exist. " r"Expected one or more of " - r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*", + r"[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*", ): data.isel(not_a_dim=slice(0, 2), missing_dims="warn") assert_identical(data, data.isel(not_a_dim=slice(0, 2), missing_dims="ignore")) @@ -1042,19 +1022,19 @@ def test_isel(self): assert {"time": 20, "dim2": 9, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(ret.indexes) + assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=slice(2), dim1=0, dim2=slice(5)) assert {"time": 2, "dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(ret.indexes) + assert set(data.xindexes) == set(ret.xindexes) ret = data.isel(time=0, dim1=0, dim2=slice(5)) assert {"dim2": 5, "dim3": 10} == ret.dims assert set(data.data_vars) == set(ret.data_vars) assert set(data.coords) == set(ret.coords) - assert set(data.indexes) == set(list(ret.indexes) + ["time"]) + assert set(data.xindexes) == set(list(ret.xindexes) + ["time"]) def test_isel_fancy(self): # isel with fancy indexing. @@ -1131,9 +1111,9 @@ def test_isel_fancy(self): data.isel(dim2=(("points",), pdim2), dim1=(("points",), pdim1)), ) # make sure we're raising errors in the right places - with raises_regex(IndexError, "Dimensions of indexers mismatch"): + with pytest.raises(IndexError, match=r"Dimensions of indexers mismatch"): data.isel(dim1=(("points",), [1, 2]), dim2=(("points",), [1, 2, 3])) - with raises_regex(TypeError, "cannot use a Dataset"): + with pytest.raises(TypeError, match=r"cannot use a Dataset"): data.isel(dim1=Dataset({"points": [1, 2]})) # test to be sure we keep around variables that were not indexed @@ -1152,7 +1132,7 @@ def test_isel_fancy(self): assert "station" in actual.dims assert_identical(actual["station"].drop_vars(["dim2"]), stations["station"]) - with raises_regex(ValueError, "conflicting values for "): + with pytest.raises(ValueError, match=r"conflicting values for "): data.isel( dim1=DataArray( [0, 1, 2], dims="station", coords={"station": [0, 1, 2]} @@ -1191,7 +1171,7 @@ def test_isel_fancy(self): assert_array_equal(actual["var3"], expected_var3) def test_isel_dataarray(self): - """ Test for indexing by DataArray """ + """Test for indexing by DataArray""" data = create_test_data() # indexing with DataArray with same-name coordinates. indexing_da = DataArray( @@ -1205,12 +1185,12 @@ def test_isel_dataarray(self): indexing_da = DataArray( np.arange(1, 4), dims=["dim2"], coords={"dim2": np.random.randn(3)} ) - with raises_regex(IndexError, "dimension coordinate 'dim2'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): actual = data.isel(dim2=indexing_da) # Also the case for DataArray - with raises_regex(IndexError, "dimension coordinate 'dim2'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): actual = data["var2"].isel(dim2=indexing_da) - with raises_regex(IndexError, "dimension coordinate 'dim2'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): data["dim2"].isel(dim2=indexing_da) # same name coordinate which does not conflict @@ -1277,7 +1257,7 @@ def test_isel_dataarray(self): # indexer generated from coordinates indexing_ds = Dataset({}, coords={"dim2": [0, 1, 2]}) - with raises_regex(IndexError, "dimension coordinate 'dim2'"): + with pytest.raises(IndexError, match=r"dimension coordinate 'dim2'"): actual = data.isel(dim2=indexing_ds["dim2"]) def test_sel(self): @@ -1390,13 +1370,13 @@ def test_sel_dataarray_mindex(self): ) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="x")) - actual_sel = mds.sel(x=DataArray(mds.indexes["x"][:3], dims="x")) + actual_sel = mds.sel(x=DataArray(midx[:3], dims="x")) assert actual_isel["x"].dims == ("x",) assert actual_sel["x"].dims == ("x",) assert_identical(actual_isel, actual_sel) actual_isel = mds.isel(x=xr.DataArray(np.arange(3), dims="z")) - actual_sel = mds.sel(x=Variable("z", mds.indexes["x"][:3])) + actual_sel = mds.sel(x=Variable("z", midx[:3])) assert actual_isel["x"].dims == ("z",) assert actual_sel["x"].dims == ("z",) assert_identical(actual_isel, actual_sel) @@ -1406,19 +1386,19 @@ def test_sel_dataarray_mindex(self): x=xr.DataArray(np.arange(3), dims="z", coords={"z": [0, 1, 2]}) ) actual_sel = mds.sel( - x=xr.DataArray(mds.indexes["x"][:3], dims="z", coords={"z": [0, 1, 2]}) + x=xr.DataArray(midx[:3], dims="z", coords={"z": [0, 1, 2]}) ) assert actual_isel["x"].dims == ("z",) assert actual_sel["x"].dims == ("z",) assert_identical(actual_isel, actual_sel) # Vectorized indexing with level-variables raises an error - with raises_regex(ValueError, "Vectorized selection is "): + with pytest.raises(ValueError, match=r"Vectorized selection is "): mds.sel(one=["a", "b"]) - with raises_regex( + with pytest.raises( ValueError, - "Vectorized selection is not available along MultiIndex variable: x", + match=r"Vectorized selection is not available along coordinate 'x' with a multi-index", ): mds.sel( x=xr.DataArray( @@ -1531,11 +1511,11 @@ def test_head(self): actual = data.head() assert_equal(expected, actual) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): data.head([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): data.head(dim2=3.1) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): data.head(time=-3) def test_tail(self): @@ -1557,11 +1537,11 @@ def test_tail(self): actual = data.tail() assert_equal(expected, actual) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): data.tail([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): data.tail(dim2=3.1) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): data.tail(time=-3) def test_thin(self): @@ -1575,13 +1555,13 @@ def test_thin(self): actual = data.thin(6) assert_equal(expected, actual) - with raises_regex(TypeError, "either dict-like or a single int"): + with pytest.raises(TypeError, match=r"either dict-like or a single int"): data.thin([3]) - with raises_regex(TypeError, "expected integer type"): + with pytest.raises(TypeError, match=r"expected integer type"): data.thin(dim2=3.1) - with raises_regex(ValueError, "cannot be zero"): + with pytest.raises(ValueError, match=r"cannot be zero"): data.thin(time=0) - with raises_regex(ValueError, "expected positive int"): + with pytest.raises(ValueError, match=r"expected positive int"): data.thin(time=-3) @pytest.mark.filterwarnings("ignore::DeprecationWarning") @@ -1695,15 +1675,15 @@ def test_sel_method(self): actual = data.sel(dim2=[1.45], method="backfill") assert_identical(expected, actual) - with raises_regex(NotImplementedError, "slice objects"): + with pytest.raises(NotImplementedError, match=r"slice objects"): data.sel(dim2=slice(1, 3), method="ffill") - with raises_regex(TypeError, "``method``"): + with pytest.raises(TypeError, match=r"``method``"): # this should not pass silently data.sel(method=data) # cannot pass method if there is no associated coordinate - with raises_regex(ValueError, "cannot supply"): + with pytest.raises(ValueError, match=r"cannot supply"): data.sel(dim1=0, method="nearest") def test_loc(self): @@ -1711,10 +1691,8 @@ def test_loc(self): expected = data.sel(dim3="a") actual = data.loc[dict(dim3="a")] assert_identical(expected, actual) - with raises_regex(TypeError, "can only lookup dict"): + with pytest.raises(TypeError, match=r"can only lookup dict"): data.loc["a"] - with pytest.raises(TypeError): - data.loc[dict(dim3="a")] = 0 def test_selection_multiindex(self): mindex = pd.MultiIndex.from_product( @@ -1768,6 +1746,27 @@ def test_broadcast_like(self): assert_identical(original2.broadcast_like(original1), expected2) + def test_to_pandas(self): + # 0D -> series + actual = Dataset({"a": 1, "b": 2}).to_pandas() + expected = pd.Series([1, 2], ["a", "b"]) + assert_array_equal(actual, expected) + + # 1D -> dataframe + x = np.random.randn(10) + y = np.random.randn(10) + t = list("abcdefghij") + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + actual = ds.to_pandas() + expected = ds.to_dataframe() + assert expected.equals(actual), (expected, actual) + + # 2D -> error + x2d = np.random.randn(10, 10) + y2d = np.random.randn(10, 10) + with pytest.raises(ValueError, match=r"cannot convert Datasets"): + Dataset({"a": (["t", "r"], x2d), "b": (["t", "r"], y2d)}).to_pandas() + def test_reindex_like(self): data = create_test_data() data["letters"] = ("dim3", 10 * ["a"]) @@ -1802,7 +1801,9 @@ def test_reindex(self): actual = data.reindex(dim1=data["dim1"].to_index()) assert_identical(actual, expected) - with raises_regex(ValueError, "cannot reindex or align along dimension"): + with pytest.raises( + ValueError, match=r"cannot reindex or align along dimension" + ): data.reindex(dim1=data["dim1"][:5]) expected = data.isel(dim2=slice(5)) @@ -1813,13 +1814,13 @@ def test_reindex(self): actual = data.reindex({"dim2": data["dim2"]}) expected = data assert_identical(actual, expected) - with raises_regex(ValueError, "cannot specify both"): + with pytest.raises(ValueError, match=r"cannot specify both"): data.reindex({"x": 0}, x=0) - with raises_regex(ValueError, "dictionary"): + with pytest.raises(ValueError, match=r"dictionary"): data.reindex("foo") # invalid dimension - with raises_regex(ValueError, "invalid reindex dim"): + with pytest.raises(ValueError, match=r"invalid reindex dim"): data.reindex(invalid=0) # out of order @@ -2031,7 +2032,7 @@ def test_align(self): assert np.isnan(left2["var3"][-2:]).all() - with raises_regex(ValueError, "invalid value for join"): + with pytest.raises(ValueError, match=r"invalid value for join"): align(left, right, join="foobar") with pytest.raises(TypeError): align(left, right, foo="bar") @@ -2044,7 +2045,7 @@ def test_align_exact(self): assert_identical(left1, left) assert_identical(left2, left) - with raises_regex(ValueError, "indexes .* not equal"): + with pytest.raises(ValueError, match=r"indexes .* not equal"): xr.align(left, right, join="exact") def test_align_override(self): @@ -2066,7 +2067,7 @@ def test_align_override(self): assert_identical(left.isel(x=0, drop=True), new_left) assert_identical(right, new_right) - with raises_regex(ValueError, "Indexes along dimension 'x' don't have"): + with pytest.raises(ValueError, match=r"Indexes along dimension 'x' don't have"): xr.align(left.isel(x=0).expand_dims("x"), right, join="override") def test_align_exclude(self): @@ -2141,7 +2142,7 @@ def test_align_non_unique(self): assert_identical(x2, x) y = Dataset({"bar": ("x", [6, 7]), "x": [0, 1]}) - with raises_regex(ValueError, "cannot reindex or align"): + with pytest.raises(ValueError, match=r"cannot reindex or align"): align(x, y) def test_align_str_dtype(self): @@ -2301,7 +2302,7 @@ def test_drop_variables(self): actual = data.drop_vars(["time"]) assert_identical(expected, actual) - with raises_regex(ValueError, "cannot be found"): + with pytest.raises(ValueError, match=r"cannot be found"): data.drop_vars("not_found_here") actual = data.drop_vars("not_found_here", errors="ignore") @@ -2375,7 +2376,7 @@ def test_drop_index_labels(self): expected = data.isel(y=[0, 2]) assert_identical(expected, actual) - with raises_regex(KeyError, "not found in axis"): + with pytest.raises(KeyError, match=r"not found in axis"): data.drop_sel(x=0) def test_drop_labels_by_keyword(self): @@ -2396,7 +2397,7 @@ def test_drop_labels_by_keyword(self): with pytest.warns(FutureWarning): data.drop(arr.coords) with pytest.warns(FutureWarning): - data.drop(arr.indexes) + data.drop(arr.xindexes) assert_array_equal(ds1.coords["x"], ["b"]) assert_array_equal(ds2.coords["x"], ["b"]) @@ -2569,11 +2570,11 @@ def test_copy_coords(self, deep, expected_orig): def test_copy_with_data_errors(self): orig = create_test_data() new_var1 = np.arange(orig["var1"].size).reshape(orig["var1"].shape) - with raises_regex(ValueError, "Data must be dict-like"): + with pytest.raises(ValueError, match=r"Data must be dict-like"): orig.copy(data=new_var1) - with raises_regex(ValueError, "only contain variables in original"): + with pytest.raises(ValueError, match=r"only contain variables in original"): orig.copy(data={"not_in_original": new_var1}) - with raises_regex(ValueError, "contain all variables in original"): + with pytest.raises(ValueError, match=r"contain all variables in original"): orig.copy(data={"var1": new_var1}) def test_rename(self): @@ -2601,10 +2602,10 @@ def test_rename(self): assert "var1" not in renamed assert "dim2" not in renamed - with raises_regex(ValueError, "cannot rename 'not_a_var'"): + with pytest.raises(ValueError, match=r"cannot rename 'not_a_var'"): data.rename({"not_a_var": "nada"}) - with raises_regex(ValueError, "'var1' conflicts"): + with pytest.raises(ValueError, match=r"'var1' conflicts"): data.rename({"var2": "var1"}) # verify that we can rename a variable without accessing the data @@ -2621,7 +2622,7 @@ def test_rename_old_name(self): # regtest for GH1477 data = create_test_data() - with raises_regex(ValueError, "'samecol' conflicts"): + with pytest.raises(ValueError, match=r"'samecol' conflicts"): data.rename({"var1": "samecol", "var2": "samecol"}) # This shouldn't cause any problems. @@ -2675,7 +2676,7 @@ def test_rename_multiindex(self): [([1, 2]), ([3, 4])], names=["level0", "level1"] ) data = Dataset({}, {"x": mindex}) - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): data.rename({"x": "level0"}) @requires_cftime @@ -2686,21 +2687,23 @@ def test_rename_does_not_change_CFTimeIndex_type(self): orig = Dataset(coords={"time": time}) renamed = orig.rename(time="time_new") - assert "time_new" in renamed.indexes - assert isinstance(renamed.indexes["time_new"], CFTimeIndex) - assert renamed.indexes["time_new"].name == "time_new" + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when CFTimeIndex + # inherits from xarray.Index + assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), CFTimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" # check original has not changed - assert "time" in orig.indexes - assert isinstance(orig.indexes["time"], CFTimeIndex) - assert orig.indexes["time"].name == "time" + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), CFTimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" # note: rename_dims(time="time_new") drops "ds.indexes" renamed = orig.rename_dims() - assert isinstance(renamed.indexes["time"], CFTimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) renamed = orig.rename_vars() - assert isinstance(renamed.indexes["time"], CFTimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), CFTimeIndex) def test_rename_does_not_change_DatetimeIndex_type(self): # make sure DatetimeIndex is conderved on rename @@ -2709,21 +2712,23 @@ def test_rename_does_not_change_DatetimeIndex_type(self): orig = Dataset(coords={"time": time}) renamed = orig.rename(time="time_new") - assert "time_new" in renamed.indexes - assert isinstance(renamed.indexes["time_new"], DatetimeIndex) - assert renamed.indexes["time_new"].name == "time_new" + assert "time_new" in renamed.xindexes + # TODO: benbovy - flexible indexes: update when DatetimeIndex + # inherits from xarray.Index? + assert isinstance(renamed.xindexes["time_new"].to_pandas_index(), DatetimeIndex) + assert renamed.xindexes["time_new"].to_pandas_index().name == "time_new" # check original has not changed - assert "time" in orig.indexes - assert isinstance(orig.indexes["time"], DatetimeIndex) - assert orig.indexes["time"].name == "time" + assert "time" in orig.xindexes + assert isinstance(orig.xindexes["time"].to_pandas_index(), DatetimeIndex) + assert orig.xindexes["time"].to_pandas_index().name == "time" # note: rename_dims(time="time_new") drops "ds.indexes" renamed = orig.rename_dims() - assert isinstance(renamed.indexes["time"], DatetimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) renamed = orig.rename_vars() - assert isinstance(renamed.indexes["time"], DatetimeIndex) + assert isinstance(renamed.xindexes["time"].to_pandas_index(), DatetimeIndex) def test_swap_dims(self): original = Dataset({"x": [1, 2, 3], "y": ("x", list("abc")), "z": 42}) @@ -2732,14 +2737,17 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"]) + pd.testing.assert_index_equal( + actual.xindexes["y"].to_pandas_index(), + expected.xindexes["y"].to_pandas_index(), + ) roundtripped = actual.swap_dims({"y": "x"}) assert_identical(original.set_coords("y"), roundtripped) - with raises_regex(ValueError, "cannot swap"): + with pytest.raises(ValueError, match=r"cannot swap"): original.swap_dims({"y": "x"}) - with raises_regex(ValueError, "replacement dimension"): + with pytest.raises(ValueError, match=r"replacement dimension"): original.swap_dims({"x": "z"}) expected = Dataset( @@ -2763,7 +2771,10 @@ def test_swap_dims(self): assert_identical(expected, actual) assert isinstance(actual.variables["y"], IndexVariable) assert isinstance(actual.variables["x"], Variable) - pd.testing.assert_index_equal(actual.indexes["y"], expected.indexes["y"]) + pd.testing.assert_index_equal( + actual.xindexes["y"].to_pandas_index(), + expected.xindexes["y"].to_pandas_index(), + ) def test_expand_dims_error(self): original = Dataset( @@ -2780,13 +2791,13 @@ def test_expand_dims_error(self): attrs={"key": "entry"}, ) - with raises_regex(ValueError, "already exists"): + with pytest.raises(ValueError, match=r"already exists"): original.expand_dims(dim=["x"]) # Make sure it raises true error also for non-dimensional coordinates # which has dimension. original = original.set_coords("z") - with raises_regex(ValueError, "already exists"): + with pytest.raises(ValueError, match=r"already exists"): original.expand_dims(dim=["z"]) original = Dataset( @@ -2802,9 +2813,9 @@ def test_expand_dims_error(self): }, attrs={"key": "entry"}, ) - with raises_regex(TypeError, "value of new dimension"): + with pytest.raises(TypeError, match=r"value of new dimension"): original.expand_dims({"d": 3.2}) - with raises_regex(ValueError, "both keyword and positional"): + with pytest.raises(ValueError, match=r"both keyword and positional"): original.expand_dims({"d": 4}, e=4) def test_expand_dims_int(self): @@ -2993,7 +3004,7 @@ def test_reorder_levels(self): assert_identical(reindexed, expected) ds = Dataset({}, coords={"x": [1, 2]}) - with raises_regex(ValueError, "has no MultiIndex"): + with pytest.raises(ValueError, match=r"has no MultiIndex"): ds.reorder_levels(x=["level_1", "level_2"]) def test_stack(self): @@ -3038,9 +3049,9 @@ def test_unstack(self): def test_unstack_errors(self): ds = Dataset({"x": [1, 2, 3]}) - with raises_regex(ValueError, "does not contain the dimensions"): + with pytest.raises(ValueError, match=r"does not contain the dimensions"): ds.unstack("foo") - with raises_regex(ValueError, "do not have a MultiIndex"): + with pytest.raises(ValueError, match=r"do not have a MultiIndex"): ds.unstack("x") def test_unstack_fill_value(self): @@ -3140,7 +3151,9 @@ def test_to_stacked_array_dtype_dims(self): D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - assert y.indexes["features"].levels[1].dtype == D.y.dtype + # TODO: benbovy - flexible indexes: update when MultiIndex has its own class + # inherited from xarray.Index + assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self): @@ -3180,13 +3193,13 @@ def test_update(self): data = create_test_data(seed=0) expected = data.copy() var2 = Variable("dim1", np.arange(8)) - actual = data.update({"var2": var2}) + actual = data + actual.update({"var2": var2}) expected["var2"] = var2 assert_identical(expected, actual) actual = data.copy() - actual_result = actual.update(data) - assert actual_result is actual + actual.update(data) assert_identical(expected, actual) other = Dataset(attrs={"new": "attr"}) @@ -3216,7 +3229,7 @@ def test_update_auto_align(self): expected = Dataset({"x": ("t", [3, 4]), "y": ("t", [np.nan, 5])}, {"t": [0, 1]}) actual = ds.copy() other = {"y": ("t", [5]), "t": [1]} - with raises_regex(ValueError, "conflicting sizes"): + with pytest.raises(ValueError, match=r"conflicting sizes"): actual.update(other) actual.update(Dataset(other)) assert_identical(expected, actual) @@ -3261,9 +3274,14 @@ def test_getitem_hashable(self): expected = data["var1"] + 1 expected.name = (3, 4) assert_identical(expected, data[(3, 4)]) - with raises_regex(KeyError, "('var1', 'var2')"): + with pytest.raises(KeyError, match=r"('var1', 'var2')"): data[("var1", "var2")] + def test_getitem_multiple_dtype(self): + keys = ["foo", 1] + dataset = Dataset({key: ("dim0", range(1)) for key in keys}) + assert_identical(dataset, dataset[keys]) + def test_virtual_variables_default_coords(self): dataset = Dataset({"foo": ("x", range(10))}) expected = DataArray(range(10), dims="x", name="x") @@ -3359,7 +3377,7 @@ def test_setitem(self): data2["B"] = dv assert_identical(data1, data2) # can't assign an ND array without dimensions - with raises_regex(ValueError, "without explicit dimension names"): + with pytest.raises(ValueError, match=r"without explicit dimension names"): data2["C"] = var.values.reshape(2, 4) # but can assign a 1D array data1["C"] = var.values @@ -3370,17 +3388,69 @@ def test_setitem(self): data2["scalar"] = ([], 0) assert_identical(data1, data2) # can't use the same dimension name as a scalar var - with raises_regex(ValueError, "already exists as a scalar"): + with pytest.raises(ValueError, match=r"already exists as a scalar"): data1["newvar"] = ("scalar", [3, 4, 5]) # can't resize a used dimension - with raises_regex(ValueError, "arguments without labels"): + with pytest.raises(ValueError, match=r"arguments without labels"): data1["dim1"] = data1["dim1"][:5] # override an existing value data1["A"] = 3 * data2["A"] assert_equal(data1["A"], 3 * data2["A"]) - with pytest.raises(NotImplementedError): - data1[{"x": 0}] = 0 + # test assignment with positional and label-based indexing + data3 = data1[["var1", "var2"]] + data3["var3"] = data3.var1.isel(dim1=0) + data4 = data3.copy() + err_msg = ( + "can only set locations defined by dictionaries from Dataset.loc. Got: a" + ) + with pytest.raises(TypeError, match=err_msg): + data1.loc["a"] = 0 + err_msg = r"Variables \['A', 'B', 'scalar'\] in new values not available in original dataset:" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": 1}] = data1[{"dim2": 2}] + err_msg = "Variable 'var3': indexer {'dim2': 0} not available" + with pytest.raises(ValueError, match=err_msg): + data1[{"dim2": 0}] = 0.0 + err_msg = "Variable 'var1': indexer {'dim2': 10} not available" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": 10}] = data3[{"dim2": 2}] + err_msg = "Variable 'var1': dimension 'dim2' appears in new values" + with pytest.raises(KeyError, match=err_msg): + data4[{"dim2": 2}] = data3[{"dim2": [2]}] + err_msg = ( + "Variable 'var2': dimension order differs between original and new data" + ) + data3["var2"] = data3["var2"].T + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3]}] + data3["var2"] = data3["var2"].T + err_msg = "indexes along dimension 'dim2' are not equal" + with pytest.raises(ValueError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3[{"dim2": [2, 3, 4]}] + err_msg = "Dataset assignment only accepts DataArrays, Datasets, and scalars." + with pytest.raises(TypeError, match=err_msg): + data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values + data5 = data4.astype(str) + data5["var4"] = data4["var1"] + err_msg = "could not convert string to float: 'a'" + with pytest.raises(ValueError, match=err_msg): + data5[{"dim2": 1}] = "a" + + data4[{"dim2": 0}] = 0.0 + data4[{"dim2": 1}] = data3[{"dim2": 2}] + data4.loc[{"dim2": 1.5}] = 1.0 + data4.loc[{"dim2": 2.0}] = data3.loc[{"dim2": 2.5}] + for v, dat3 in data3.items(): + dat4 = data4[v] + assert_array_equal(dat4[{"dim2": 0}], 0.0) + assert_array_equal(dat4[{"dim2": 1}], dat3[{"dim2": 2}]) + assert_array_equal(dat4.loc[{"dim2": 1.5}], 1.0) + assert_array_equal(dat4.loc[{"dim2": 2.0}], dat3.loc[{"dim2": 2.5}]) + unchanged = [1.0, 2.5, 3.0, 3.5, 4.0] + assert_identical( + dat4.loc[{"dim2": unchanged}], dat3.loc[{"dim2": unchanged}] + ) def test_setitem_pandas(self): @@ -3489,10 +3559,45 @@ def test_setitem_align_new_indexes(self): def test_setitem_str_dtype(self, dtype): ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) + # test Dataset update ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"]) assert np.issubdtype(ds.x.dtype, dtype) + def test_setitem_using_list(self): + + # assign a list of variables + var1 = Variable(["dim1"], np.random.randn(8)) + var2 = Variable(["dim1"], np.random.randn(8)) + actual = create_test_data() + expected = actual.copy() + expected["A"] = var1 + expected["B"] = var2 + actual[["A", "B"]] = [var1, var2] + assert_identical(actual, expected) + # assign a list of dataset arrays + dv = 2 * expected[["A", "B"]] + actual[["C", "D"]] = [d.variable for d in dv.data_vars.values()] + expected[["C", "D"]] = dv + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "var_list, data, error_regex", + [ + ( + ["A", "B"], + [Variable(["dim1"], np.random.randn(8))], + r"Different lengths", + ), + ([], [Variable(["dim1"], np.random.randn(8))], r"Empty list of variables"), + (["A", "B"], xr.DataArray([1, 2]), r"assign single DataArray"), + ], + ) + def test_setitem_using_list_errors(self, var_list, data, error_regex): + actual = create_test_data() + with pytest.raises(ValueError, match=error_regex): + actual[var_list] = data + def test_assign(self): ds = Dataset() actual = ds.assign(x=[0, 1, 2], y=2) @@ -3549,7 +3654,7 @@ def test_assign_attrs(self): def test_assign_multiindex_level(self): data = create_test_multiindex() - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): data.assign(level_1=range(4)) data.assign_coords(level_1=range(4)) # raise an Error when any level name is used as dimension GH:2299 @@ -3596,7 +3701,7 @@ def test_setitem_both_non_unique_index(self): def test_setitem_multiindex_level(self): data = create_test_multiindex() - with raises_regex(ValueError, "conflicting MultiIndex"): + with pytest.raises(ValueError, match=r"conflicting MultiIndex"): data["level_1"] = range(4) def test_delitem(self): @@ -3627,7 +3732,7 @@ def get_args(v): expected = expected.set_coords(data.coords) assert_identical(expected, data.squeeze(*args)) # invalid squeeze - with raises_regex(ValueError, "cannot select a dimension"): + with pytest.raises(ValueError, match=r"cannot select a dimension"): data.squeeze("y") def test_squeeze_drop(self): @@ -3653,173 +3758,6 @@ def test_squeeze_drop(self): selected = data.squeeze(drop=True) assert_identical(data, selected) - def test_groupby(self): - data = Dataset( - {"z": (["x", "y"], np.random.randn(3, 5))}, - {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, - ) - groupby = data.groupby("x") - assert len(groupby) == 3 - expected_groups = {"a": 0, "b": 1, "c": 2} - assert groupby.groups == expected_groups - expected_items = [ - ("a", data.isel(x=0)), - ("b", data.isel(x=1)), - ("c", data.isel(x=2)), - ] - for actual, expected in zip(groupby, expected_items): - assert actual[0] == expected[0] - assert_equal(actual[1], expected[1]) - - def identity(x): - return x - - for k in ["x", "c", "y"]: - actual = data.groupby(k, squeeze=False).map(identity) - assert_equal(data, actual) - - def test_groupby_returns_new_type(self): - data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) - - actual = data.groupby("x").map(lambda ds: ds["z"]) - expected = data["z"] - assert_identical(expected, actual) - - actual = data["z"].groupby("x").map(lambda x: x.to_dataset()) - expected = data - assert_identical(expected, actual) - - def test_groupby_iter(self): - data = create_test_data() - for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]): - assert data["dim1"][n] == t - assert_equal(data["var1"][n], sub["var1"]) - assert_equal(data["var2"][n], sub["var2"]) - assert_equal(data["var3"][:, n], sub["var3"]) - - def test_groupby_errors(self): - data = create_test_data() - with raises_regex(TypeError, "`group` must be"): - data.groupby(np.arange(10)) - with raises_regex(ValueError, "length does not match"): - data.groupby(data["dim1"][:3]) - with raises_regex(TypeError, "`group` must be"): - data.groupby(data.coords["dim1"].to_index()) - - def test_groupby_reduce(self): - data = Dataset( - { - "xy": (["x", "y"], np.random.randn(3, 4)), - "xonly": ("x", np.random.randn(3)), - "yonly": ("y", np.random.randn(4)), - "letters": ("y", ["a", "a", "b", "b"]), - } - ) - - expected = data.mean("y") - expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) - actual = data.groupby("x").mean(...) - assert_allclose(expected, actual) - - actual = data.groupby("x").mean("y") - assert_allclose(expected, actual) - - letters = data["letters"] - expected = Dataset( - { - "xy": data["xy"].groupby(letters).mean(...), - "xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})), - "yonly": data["yonly"].groupby(letters).mean(), - } - ) - actual = data.groupby("letters").mean(...) - assert_allclose(expected, actual) - - def test_groupby_math(self): - def reorder_dims(x): - return x.transpose("dim1", "dim2", "dim3", "time") - - ds = create_test_data() - ds["dim1"] = ds["dim1"] - for squeeze in [True, False]: - grouped = ds.groupby("dim1", squeeze=squeeze) - - expected = reorder_dims(ds + ds.coords["dim1"]) - actual = grouped + ds.coords["dim1"] - assert_identical(expected, reorder_dims(actual)) - - actual = ds.coords["dim1"] + grouped - assert_identical(expected, reorder_dims(actual)) - - ds2 = 2 * ds - expected = reorder_dims(ds + ds2) - actual = grouped + ds2 - assert_identical(expected, reorder_dims(actual)) - - actual = ds2 + grouped - assert_identical(expected, reorder_dims(actual)) - - grouped = ds.groupby("numbers") - zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))]) - expected = (ds + Variable("dim3", np.zeros(10))).transpose( - "dim3", "dim1", "dim2", "time" - ) - actual = grouped + zeros - assert_equal(expected, actual) - - actual = zeros + grouped - assert_equal(expected, actual) - - with raises_regex(ValueError, "incompat.* grouped binary"): - grouped + ds - with raises_regex(ValueError, "incompat.* grouped binary"): - ds + grouped - with raises_regex(TypeError, "only support binary ops"): - grouped + 1 - with raises_regex(TypeError, "only support binary ops"): - grouped + grouped - with raises_regex(TypeError, "in-place operations"): - ds += grouped - - ds = Dataset( - { - "x": ("time", np.arange(100)), - "time": pd.date_range("2000-01-01", periods=100), - } - ) - with raises_regex(ValueError, "incompat.* grouped binary"): - ds + ds.groupby("time.month") - - def test_groupby_math_virtual(self): - ds = Dataset( - {"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)} - ) - grouped = ds.groupby("t.day") - actual = grouped - grouped.mean(...) - expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]]) - assert_identical(actual, expected) - - def test_groupby_nan(self): - # nan should be excluded from groupby - ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])}) - actual = ds.groupby("bar").mean(...) - expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]}) - assert_identical(actual, expected) - - def test_groupby_order(self): - # groupby should preserve variables order - ds = Dataset() - for vn in ["a", "b", "c"]: - ds[vn] = DataArray(np.arange(10), dims=["t"]) - data_vars_ref = list(ds.data_vars.keys()) - ds = ds.groupby("t").mean(...) - data_vars = list(ds.data_vars.keys()) - assert data_vars == data_vars_ref - # coords are now at the end of the list, so the test below fails - # all_vars = list(ds.variables.keys()) - # all_vars_ref = list(ds.variables.keys()) - # self.assertEqual(all_vars, all_vars_ref) - def test_resample_and_first(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) ds = Dataset( @@ -3888,6 +3826,11 @@ def test_resample_by_mean_with_keep_attrs(self): expected = ds.attrs assert expected == actual + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." + ): + ds.resample(time="1D", keep_attrs=True) + def test_resample_loffset(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) ds = Dataset( @@ -3899,11 +3842,13 @@ def test_resample_loffset(self): ) ds.attrs["dsmeta"] = "dsdata" - actual = ds.resample(time="24H", loffset="-12H").mean("time").time - expected = xr.DataArray( - ds.bar.to_series().resample("24H", loffset="-12H").mean() - ).time - assert_identical(expected, actual) + # Our use of `loffset` may change if we align our API with pandas' changes. + # ref https://github.com/pydata/xarray/pull/4537 + actual = ds.resample(time="24H", loffset="-12H").mean().bar + expected_ = ds.bar.to_series().resample("24H").mean() + expected_.index += to_offset("-12H") + expected = DataArray.from_series(expected_) + assert_allclose(actual, expected) def test_resample_by_mean_discarding_attrs(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) @@ -3975,13 +3920,13 @@ def test_resample_old_api(self): } ) - with raises_regex(TypeError, r"resample\(\) no longer supports"): + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): ds.resample("1D", "time") - with raises_regex(TypeError, r"resample\(\) no longer supports"): + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): ds.resample("1D", dim="time", how="mean") - with raises_regex(TypeError, r"resample\(\) no longer supports"): + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): ds.resample("1D", dim="time") def test_resample_ds_da_are_the_same(self): @@ -4190,7 +4135,7 @@ def test_from_dataframe_multiindex(self): assert_identical(actual, expected3) df_nonunique = df.iloc[[0, 0], :] - with raises_regex(ValueError, "non-unique MultiIndex"): + with pytest.raises(ValueError, match=r"non-unique MultiIndex"): Dataset.from_dataframe(df_nonunique) def test_from_dataframe_unsorted_levels(self): @@ -4213,7 +4158,7 @@ def test_from_dataframe_non_unique_columns(self): # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) df.columns = ["foo", "foo"] - with raises_regex(ValueError, "non-unique columns"): + with pytest.raises(ValueError, match=r"non-unique columns"): Dataset.from_dataframe(df) def test_convert_dataframe_with_many_types_and_multiindex(self): @@ -4310,7 +4255,9 @@ def test_to_and_from_dict(self): "t": {"data": t, "dims": "t"}, "b": {"dims": "t", "data": y}, } - with raises_regex(ValueError, "cannot convert dict without the key 'dims'"): + with pytest.raises( + ValueError, match=r"cannot convert dict without the key 'dims'" + ): Dataset.from_dict(d) def test_to_and_from_dict_with_time_dim(self): @@ -4444,11 +4391,11 @@ def test_dropna(self): expected = ds.isel(a=[1, 3]) assert_identical(actual, ds) - with raises_regex(ValueError, "a single dataset dimension"): + with pytest.raises(ValueError, match=r"a single dataset dimension"): ds.dropna("foo") - with raises_regex(ValueError, "invalid how"): + with pytest.raises(ValueError, match=r"invalid how"): ds.dropna("a", how="somehow") - with raises_regex(TypeError, "must specify how or thresh"): + with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) def test_fillna(self): @@ -4493,7 +4440,7 @@ def test_fillna(self): assert_identical(expected, actual) # but new data variables is not okay - with raises_regex(ValueError, "must be contained"): + with pytest.raises(ValueError, match=r"must be contained"): ds.fillna({"x": 0}) # empty argument should be OK @@ -4625,13 +4572,13 @@ def test_where_other(self): actual = ds.where(lambda x: x > 1, -1) assert_equal(expected, actual) - with raises_regex(ValueError, "cannot set"): + with pytest.raises(ValueError, match=r"cannot set"): ds.where(ds > 1, other=0, drop=True) - with raises_regex(ValueError, "indexes .* are not equal"): + with pytest.raises(ValueError, match=r"indexes .* are not equal"): ds.where(ds > 1, ds.isel(x=slice(3))) - with raises_regex(ValueError, "exact match required"): + with pytest.raises(ValueError, match=r"exact match required"): ds.where(ds > 1, ds.assign(b=2)) def test_where_drop(self): @@ -4654,7 +4601,7 @@ def test_where_drop(self): actual = ds.where(ds.a > 1, drop=True) assert_identical(expected, actual) - with raises_regex(TypeError, "must be a"): + with pytest.raises(TypeError, match=r"must be a"): ds.where(np.arange(5) > 1, drop=True) # 1d with odd coordinates @@ -4736,16 +4683,19 @@ def test_reduce(self): assert_equal(data.min(dim=["dim1"]), data.min(dim="dim1")) for reduct, expected in [ - ("dim2", ["dim1", "dim3", "time"]), - (["dim2", "time"], ["dim1", "dim3"]), - (("dim2", "time"), ["dim1", "dim3"]), - ((), ["dim1", "dim2", "dim3", "time"]), + ("dim2", ["dim3", "time", "dim1"]), + (["dim2", "time"], ["dim3", "dim1"]), + (("dim2", "time"), ["dim3", "dim1"]), + ((), ["dim2", "dim3", "time", "dim1"]), ]: actual = list(data.min(dim=reduct).dims) assert actual == expected assert_equal(data.mean(dim=[]), data) + with pytest.raises(ValueError): + data.mean(axis=0) + def test_reduce_coords(self): # regression test for GH1470 data = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"b": 4}) @@ -4772,7 +4722,7 @@ def test_mean_uint_dtype(self): def test_reduce_bad_dim(self): data = create_test_data() - with raises_regex(ValueError, "Dataset does not contain"): + with pytest.raises(ValueError, match=r"Dataset does not contain"): data.mean(dim="bad_dim") def test_reduce_cumsum(self): @@ -4786,34 +4736,38 @@ def test_reduce_cumsum(self): ) assert_identical(expected, data.cumsum()) - def test_reduce_cumsum_test_dims(self): + @pytest.mark.parametrize( + "reduct, expected", + [ + ("dim1", ["dim2", "dim3", "time", "dim1"]), + ("dim2", ["dim3", "time", "dim1", "dim2"]), + ("dim3", ["dim2", "time", "dim1", "dim3"]), + ("time", ["dim2", "dim3", "dim1"]), + ], + ) + @pytest.mark.parametrize("func", ["cumsum", "cumprod"]) + def test_reduce_cumsum_test_dims(self, reduct, expected, func): data = create_test_data() - for cumfunc in ["cumsum", "cumprod"]: - with raises_regex(ValueError, "Dataset does not contain"): - getattr(data, cumfunc)(dim="bad_dim") - - # ensure dimensions are correct - for reduct, expected in [ - ("dim1", ["dim1", "dim2", "dim3", "time"]), - ("dim2", ["dim1", "dim2", "dim3", "time"]), - ("dim3", ["dim1", "dim2", "dim3", "time"]), - ("time", ["dim1", "dim2", "dim3"]), - ]: - actual = getattr(data, cumfunc)(dim=reduct).dims - assert list(actual) == expected + with pytest.raises(ValueError, match=r"Dataset does not contain"): + getattr(data, func)(dim="bad_dim") + + # ensure dimensions are correct + actual = getattr(data, func)(dim=reduct).dims + assert list(actual) == expected def test_reduce_non_numeric(self): data1 = create_test_data(seed=44) data2 = create_test_data(seed=44) - add_vars = {"var4": ["dim1", "dim2"]} + add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.dims[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - assert "var4" not in data1.mean() + assert "var4" not in data1.mean() and "var5" not in data1.mean() assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) + assert "var4" not in data1.mean(dim="dim2") and "var5" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" @@ -4923,12 +4877,11 @@ def mean_only_one_axis(x, axis): actual = ds.reduce(mean_only_one_axis, "y") assert_identical(expected, actual) - with raises_regex(TypeError, "missing 1 required positional argument: 'axis'"): + with pytest.raises( + TypeError, match=r"missing 1 required positional argument: 'axis'" + ): ds.reduce(mean_only_one_axis) - with raises_regex(TypeError, "non-integer axis"): - ds.reduce(mean_only_one_axis, axis=["x", "y"]) - def test_reduce_no_axis(self): def total_sum(x): return np.sum(x.flatten()) @@ -4938,10 +4891,7 @@ def total_sum(x): actual = ds.reduce(total_sum) assert_identical(expected, actual) - with raises_regex(TypeError, "unexpected keyword argument 'axis'"): - ds.reduce(total_sum, axis=0) - - with raises_regex(TypeError, "unexpected keyword argument 'axis'"): + with pytest.raises(TypeError, match=r"unexpected keyword argument 'axis'"): ds.reduce(total_sum, dim="x") def test_reduce_keepdims(self): @@ -4959,13 +4909,13 @@ def test_reduce_keepdims(self): # Coordinates involved in the reduction should be removed actual = ds.mean(keepdims=True) expected = Dataset( - {"a": (["x", "y"], np.mean(ds.a, keepdims=True))}, coords={"c": ds.c} + {"a": (["x", "y"], np.mean(ds.a, keepdims=True).data)}, coords={"c": ds.c} ) assert_identical(expected, actual) actual = ds.mean("x", keepdims=True) expected = Dataset( - {"a": (["x", "y"], np.mean(ds.a, axis=0, keepdims=True))}, + {"a": (["x", "y"], np.mean(ds.a, axis=0, keepdims=True).data)}, coords={"y": ds.y, "c": ds.c}, ) assert_identical(expected, actual) @@ -5019,9 +4969,15 @@ def test_rank(self): assert list(z.coords) == list(ds.coords) assert list(x.coords) == list(y.coords) # invalid dim - with raises_regex(ValueError, "does not contain"): + with pytest.raises(ValueError, match=r"does not contain"): x.rank("invalid_dim") + def test_rank_use_bottleneck(self): + ds = Dataset({"a": ("x", [0, np.nan, 2]), "b": ("y", [4, 6, 3, 4])}) + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + ds.rank("x") + def test_count(self): ds = Dataset({"x": ("a", [np.nan, 1]), "y": 0, "z": np.nan}) expected = Dataset({"x": 1, "y": 1, "z": 0}) @@ -5186,7 +5142,7 @@ def test_dataset_math_errors(self): ds["foo"] += ds with pytest.raises(TypeError): ds["foo"].variable += ds - with raises_regex(ValueError, "must have the same"): + with pytest.raises(ValueError, match=r"must have the same"): ds += ds[["bar"]] # verify we can rollback in-place operations if something goes wrong @@ -5248,10 +5204,19 @@ def test_dataset_transpose(self): expected_dims = tuple(d for d in new_order if d in ds[k].dims) assert actual[k].dims == expected_dims - with raises_regex(ValueError, "permuted"): - ds.transpose("dim1", "dim2", "dim3") - with raises_regex(ValueError, "permuted"): - ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim") + # test missing dimension, raise error + with pytest.raises(ValueError): + ds.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = ds.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = ds.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + actual = ds.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) assert "T" not in dir(ds) @@ -5332,12 +5297,12 @@ def test_dataset_diff_n2(self): def test_dataset_diff_exception_n_neg(self): ds = create_test_data(seed=1) - with raises_regex(ValueError, "must be non-negative"): + with pytest.raises(ValueError, match=r"must be non-negative"): ds.diff("dim2", n=-1) def test_dataset_diff_exception_label_str(self): ds = create_test_data(seed=1) - with raises_regex(ValueError, "'label' argument has to"): + with pytest.raises(ValueError, match=r"'label' argument has to"): ds.diff("dim2", label="raise_me") @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": -10}]) @@ -5355,7 +5320,7 @@ def test_shift(self, fill_value): expected = Dataset({"foo": ("x", [fill_value, 1, 2])}, coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, "dimensions"): + with pytest.raises(ValueError, match=r"dimensions"): ds.shift(foo=123) def test_roll_coords(self): @@ -5368,7 +5333,7 @@ def test_roll_coords(self): expected = Dataset({"foo": ("x", [3, 1, 2])}, ex_coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, "dimensions"): + with pytest.raises(ValueError, match=r"dimensions"): ds.roll(foo=123, roll_coords=True) def test_roll_no_coords(self): @@ -5380,7 +5345,7 @@ def test_roll_no_coords(self): expected = Dataset({"foo": ("x", [3, 1, 2])}, coords, attrs) assert_identical(expected, actual) - with raises_regex(ValueError, "dimensions"): + with pytest.raises(ValueError, match=r"dimensions"): ds.roll(abc=321, roll_coords=False) def test_roll_coords_none(self): @@ -5420,11 +5385,11 @@ def test_real_and_imag(self): def test_setattr_raises(self): ds = Dataset({}, coords={"scalar": 1}, attrs={"foo": "bar"}) - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): ds.scalar = 2 - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): ds.foo = 2 - with raises_regex(AttributeError, "cannot set attr"): + with pytest.raises(AttributeError, match=r"cannot set attr"): ds.other = 2 def test_filter_by_attrs(self): @@ -5498,8 +5463,8 @@ def test_binary_op_propagate_indexes(self): ds = Dataset( {"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})} ) - expected = ds.indexes["x"] - actual = (ds * 2).indexes["x"] + expected = ds.xindexes["x"] + actual = (ds * 2).xindexes["x"] assert expected is actual def test_binary_op_join_setting(self): @@ -5807,17 +5772,141 @@ def test_astype_attrs(self): assert not data.astype(float, keep_attrs=False).attrs assert not data.astype(float, keep_attrs=False).var1.attrs + @pytest.mark.parametrize("parser", ["pandas", "python"]) + @pytest.mark.parametrize( + "engine", ["python", None, pytest.param("numexpr", marks=[requires_numexpr])] + ) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=[requires_dask])] + ) + def test_query(self, backend, engine, parser): + """Test querying a dataset.""" + + # setup test data + np.random.seed(42) + a = np.arange(0, 10, 1) + b = np.random.randint(0, 100, size=10) + c = np.linspace(0, 1, 20) + d = np.random.choice(["foo", "bar", "baz"], size=30, replace=True).astype( + object + ) + e = np.arange(0, 10 * 20).reshape(10, 20) + f = np.random.normal(0, 1, size=(10, 20, 30)) + if backend == "numpy": + ds = Dataset( + { + "a": ("x", a), + "b": ("x", b), + "c": ("y", c), + "d": ("z", d), + "e": (("x", "y"), e), + "f": (("x", "y", "z"), f), + } + ) + elif backend == "dask": + ds = Dataset( + { + "a": ("x", da.from_array(a, chunks=3)), + "b": ("x", da.from_array(b, chunks=3)), + "c": ("y", da.from_array(c, chunks=7)), + "d": ("z", da.from_array(d, chunks=12)), + "e": (("x", "y"), da.from_array(e, chunks=(3, 7))), + "f": (("x", "y", "z"), da.from_array(f, chunks=(3, 7, 12))), + } + ) + + # query single dim, single variable + actual = ds.query(x="a > 5", engine=engine, parser=parser) + expect = ds.isel(x=(a > 5)) + assert_identical(expect, actual) + + # query single dim, single variable, via dict + actual = ds.query(dict(x="a > 5"), engine=engine, parser=parser) + expect = ds.isel(dict(x=(a > 5))) + assert_identical(expect, actual) + + # query single dim, single variable + actual = ds.query(x="b > 50", engine=engine, parser=parser) + expect = ds.isel(x=(b > 50)) + assert_identical(expect, actual) + + # query single dim, single variable + actual = ds.query(y="c < .5", engine=engine, parser=parser) + expect = ds.isel(y=(c < 0.5)) + assert_identical(expect, actual) + + # query single dim, single string variable + if parser == "pandas": + # N.B., this query currently only works with the pandas parser + # xref https://github.com/pandas-dev/pandas/issues/40436 + actual = ds.query(z='d == "bar"', engine=engine, parser=parser) + expect = ds.isel(z=(d == "bar")) + assert_identical(expect, actual) + + # query single dim, multiple variables + actual = ds.query(x="(a > 5) & (b > 50)", engine=engine, parser=parser) + expect = ds.isel(x=((a > 5) & (b > 50))) + assert_identical(expect, actual) + + # query single dim, multiple variables with computation + actual = ds.query(x="(a * b) > 250", engine=engine, parser=parser) + expect = ds.isel(x=(a * b) > 250) + assert_identical(expect, actual) + + # check pandas query syntax is supported + if parser == "pandas": + actual = ds.query(x="(a > 5) and (b > 50)", engine=engine, parser=parser) + expect = ds.isel(x=((a > 5) & (b > 50))) + assert_identical(expect, actual) + + # query multiple dims via kwargs + actual = ds.query(x="a > 5", y="c < .5", engine=engine, parser=parser) + expect = ds.isel(x=(a > 5), y=(c < 0.5)) + assert_identical(expect, actual) + + # query multiple dims via kwargs + if parser == "pandas": + actual = ds.query( + x="a > 5", y="c < .5", z="d == 'bar'", engine=engine, parser=parser + ) + expect = ds.isel(x=(a > 5), y=(c < 0.5), z=(d == "bar")) + assert_identical(expect, actual) + + # query multiple dims via dict + actual = ds.query(dict(x="a > 5", y="c < .5"), engine=engine, parser=parser) + expect = ds.isel(dict(x=(a > 5), y=(c < 0.5))) + assert_identical(expect, actual) + + # query multiple dims via dict + if parser == "pandas": + actual = ds.query( + dict(x="a > 5", y="c < .5", z="d == 'bar'"), + engine=engine, + parser=parser, + ) + expect = ds.isel(dict(x=(a > 5), y=(c < 0.5), z=(d == "bar"))) + assert_identical(expect, actual) -# Py.test tests + # test error handling + with pytest.raises(ValueError): + ds.query("a > 5") # must be dict or kwargs + with pytest.raises(ValueError): + ds.query(x=(a > 5)) # must be query string + with pytest.raises(IndexError): + ds.query(y="a > 5") # wrong length dimension + with pytest.raises(IndexError): + ds.query(x="c < .5") # wrong length dimension + with pytest.raises(IndexError): + ds.query(x="e > 100") # wrong number of dimensions + with pytest.raises(UndefinedVariableError): + ds.query(x="spam > 50") # name not present -@pytest.fixture(params=[None]) -def data_set(request): - return create_test_data(request.param) +# pytest tests — new tests should go here, rather than in the class. @pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) -def test_isin(test_elements): +def test_isin(test_elements, backend): expected = Dataset( data_vars={ "var1": (("dim1",), [0, 1]), @@ -5826,6 +5915,9 @@ def test_isin(test_elements): } ).astype("bool") + if backend == "dask": + expected = expected.chunk() + result = Dataset( data_vars={ "var1": (("dim1",), [0, 1]), @@ -5837,33 +5929,6 @@ def test_isin(test_elements): assert_equal(result, expected) -@pytest.mark.skipif(not has_dask, reason="requires dask") -@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) -def test_isin_dask(test_elements): - expected = Dataset( - data_vars={ - "var1": (("dim1",), [0, 1]), - "var2": (("dim1",), [1, 1]), - "var3": (("dim1",), [0, 1]), - } - ).astype("bool") - - result = ( - Dataset( - data_vars={ - "var1": (("dim1",), [0, 1]), - "var2": (("dim1",), [1, 2]), - "var3": (("dim1",), [0, 1]), - } - ) - .chunk(1) - .isin(test_elements) - .compute() - ) - - assert_equal(result, expected) - - def test_isin_dataset(): ds = Dataset({"x": [1, 2]}) with pytest.raises(TypeError): @@ -5910,17 +5975,18 @@ def test_constructor_raises_with_invalid_coords(unaligned_coords): xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords) -def test_dir_expected_attrs(data_set): +@pytest.mark.parametrize("ds", [3], indirect=True) +def test_dir_expected_attrs(ds): some_expected_attrs = {"pipe", "mean", "isnull", "var1", "dim2", "numbers"} - result = dir(data_set) + result = dir(ds) assert set(result) >= some_expected_attrs -def test_dir_non_string(data_set): +def test_dir_non_string(ds): # add a numbered key to ensure this doesn't break dir - data_set[5] = "foo" - result = dir(data_set) + ds[5] = "foo" + result = dir(ds) assert 5 not in result # GH2172 @@ -5930,129 +5996,50 @@ def test_dir_non_string(data_set): dir(x2) -def test_dir_unicode(data_set): - data_set["unicode"] = "uni" - result = dir(data_set) +def test_dir_unicode(ds): + ds["unicode"] = "uni" + result = dir(ds) assert "unicode" in result @pytest.fixture(params=[1]) -def ds(request): +def ds(request, backend): if request.param == 1: - return Dataset( - { - "z1": (["y", "x"], np.random.randn(2, 8)), - "z2": (["time", "y"], np.random.randn(10, 2)), - }, - { - "x": ("x", np.linspace(0, 1.0, 8)), - "time": ("time", np.linspace(0, 1.0, 10)), - "c": ("y", ["a", "b"]), - "y": range(2), - }, + ds = Dataset( + dict( + z1=(["y", "x"], np.random.randn(2, 8)), + z2=(["time", "y"], np.random.randn(10, 2)), + ), + dict( + x=("x", np.linspace(0, 1.0, 8)), + time=("time", np.linspace(0, 1.0, 10)), + c=("y", ["a", "b"]), + y=range(2), + ), ) - - if request.param == 2: - return Dataset( - { - "z1": (["time", "y"], np.random.randn(10, 2)), - "z2": (["time"], np.random.randn(10)), - "z3": (["x", "time"], np.random.randn(8, 10)), - }, - { - "x": ("x", np.linspace(0, 1.0, 8)), - "time": ("time", np.linspace(0, 1.0, 10)), - "c": ("y", ["a", "b"]), - "y": range(2), - }, + elif request.param == 2: + ds = Dataset( + dict( + z1=(["time", "y"], np.random.randn(10, 2)), + z2=(["time"], np.random.randn(10)), + z3=(["x", "time"], np.random.randn(8, 10)), + ), + dict( + x=("x", np.linspace(0, 1.0, 8)), + time=("time", np.linspace(0, 1.0, 10)), + c=("y", ["a", "b"]), + y=range(2), + ), ) + elif request.param == 3: + ds = create_test_data() + else: + raise ValueError + if backend == "dask": + return ds.chunk() -def test_coarsen_absent_dims_error(ds): - with raises_regex(ValueError, "not found in Dataset."): - ds.coarsen(foo=2) - - -@pytest.mark.parametrize("dask", [True, False]) -@pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) -def test_coarsen(ds, dask, boundary, side): - if dask and has_dask: - ds = ds.chunk({"x": 4}) - - actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() - assert_equal( - actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() - ) - # coordinate should be mean by default - assert_equal( - actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() - ) - - -@pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_coords(ds, dask): - if dask and has_dask: - ds = ds.chunk({"x": 4}) - - # check if coord_func works - actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() - assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) - assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) - - # raise if exact - with pytest.raises(ValueError): - ds.coarsen(x=3).mean() - # should be no error - ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() - - # working test with pd.time - da = xr.DataArray( - np.linspace(0, 365, num=364), - dims="time", - coords={"time": pd.date_range("15/12/1999", periods=364)}, - ) - actual = da.coarsen(time=2).mean() - - -@requires_cftime -def test_coarsen_coords_cftime(): - times = xr.cftime_range("2000", periods=6) - da = xr.DataArray(range(6), [("time", times)]) - actual = da.coarsen(time=3).mean() - expected_times = xr.cftime_range("2000-01-02", freq="3D", periods=2) - np.testing.assert_array_equal(actual.time, expected_times) - - -def test_coarsen_keep_attrs(): - _attrs = {"units": "test", "long_name": "testing"} - - var1 = np.linspace(10, 15, 100) - var2 = np.linspace(5, 10, 100) - coords = np.linspace(1, 10, 100) - - ds = Dataset( - data_vars={"var1": ("coord", var1), "var2": ("coord", var2)}, - coords={"coord": coords}, - attrs=_attrs, - ) - - ds2 = ds.copy(deep=True) - - # Test dropped attrs - dat = ds.coarsen(coord=5).mean() - assert dat.attrs == {} - - # Test kept attrs using dataset keyword - dat = ds.coarsen(coord=5, keep_attrs=True).mean() - assert dat.attrs == _attrs - - # Test kept attrs using global option - with set_options(keep_attrs=True): - dat = ds.coarsen(coord=5).mean() - assert dat.attrs == _attrs - - # Test kept attrs in original object - xr.testing.assert_identical(ds, ds2) + return ds @pytest.mark.parametrize( @@ -6131,41 +6118,6 @@ def test_rolling_keep_attrs(funcname, argument): assert result.da_not_rolled.name == "da_not_rolled" -def test_rolling_keep_attrs_deprecated(): - global_attrs = {"units": "test", "long_name": "testing"} - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - ds = Dataset( - data_vars={"da": ("coord", data)}, - coords={"coord": coords}, - attrs=global_attrs, - ) - ds.da.attrs = attrs_da - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = ds.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") - - assert result.attrs == {} - assert result.da.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = ds.rolling(dim={"coord": 5}, keep_attrs=True).construct( - "window_dim", keep_attrs=False - ) - - assert result.attrs == {} - assert result.da.attrs == {} - - def test_rolling_properties(ds): # catching invalid args with pytest.raises(ValueError, match="window must be > 0"): @@ -6180,6 +6132,7 @@ def test_rolling_properties(ds): @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @pytest.mark.parametrize("key", ("z1", "z2")) +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): bn = pytest.importorskip("bottleneck", minversion="1.1") @@ -6205,6 +6158,7 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): @requires_numbagg +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) def test_rolling_exp(ds): result = ds.rolling_exp(time=10, window_type="span").mean() @@ -6212,6 +6166,7 @@ def test_rolling_exp(ds): @requires_numbagg +@pytest.mark.parametrize("backend", ["numpy"], indirect=True) def test_rolling_exp_keep_attrs(ds): attrs_global = {"attrs": "global"} @@ -6247,6 +6202,11 @@ def test_rolling_exp_keep_attrs(ds): assert result.attrs == {} assert result.z1.attrs == {} + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``rolling_exp`` has no effect." + ): + ds.rolling_exp(time=10, keep_attrs=True) + @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @@ -6603,8 +6563,69 @@ def test_integrate(dask): with pytest.raises(ValueError): da.integrate("x2d") - with pytest.warns(FutureWarning): - da.integrate(dim="x") + +@requires_scipy +@pytest.mark.parametrize("dask", [True, False]) +def test_cumulative_integrate(dask): + rs = np.random.RandomState(43) + coord = [0.2, 0.35, 0.4, 0.6, 0.7, 0.75, 0.76, 0.8] + + da = xr.DataArray( + rs.randn(8, 6), + dims=["x", "y"], + coords={ + "x": coord, + "x2": (("x",), rs.randn(8)), + "z": 3, + "x2d": (("x", "y"), rs.randn(8, 6)), + }, + ) + if dask and has_dask: + da = da.chunk({"x": 4}) + + ds = xr.Dataset({"var": da}) + + # along x + actual = da.cumulative_integrate("x") + + # From scipy-1.6.0 cumtrapz is renamed to cumulative_trapezoid, but cumtrapz is + # still provided for backward compatibility + from scipy.integrate import cumtrapz + + expected_x = xr.DataArray( + cumtrapz(da.compute(), da["x"], axis=0, initial=0.0), + dims=["x", "y"], + coords=da.coords, + ) + assert_allclose(expected_x, actual.compute()) + assert_equal( + ds["var"].cumulative_integrate("x"), + ds.cumulative_integrate("x")["var"], + ) + + # make sure result is also a dask array (if the source is dask array) + assert isinstance(actual.data, type(da.data)) + + # along y + actual = da.cumulative_integrate("y") + expected_y = xr.DataArray( + cumtrapz(da, da["y"], axis=1, initial=0.0), + dims=["x", "y"], + coords=da.coords, + ) + assert_allclose(expected_y, actual.compute()) + assert_equal(actual, ds.cumulative_integrate("y")["var"]) + assert_equal( + ds["var"].cumulative_integrate("y"), + ds.cumulative_integrate("y")["var"], + ) + + # along x and y + actual = da.cumulative_integrate(("y", "x")) + assert actual.ndim == 2 + + with pytest.raises(ValueError): + da.cumulative_integrate("x2d") @pytest.mark.parametrize("dask", [True, False]) @@ -6641,7 +6662,7 @@ def test_trapz_datetime(dask, which_datetime): actual = da.integrate("time", datetime_unit="D") expected_data = np.trapz( - da.data, + da.compute().data, duck_array_ops.datetime_to_numeric(da["time"].data, datetime_unit="D"), axis=0, ) @@ -6695,3 +6716,89 @@ def test_deepcopy_obj_array(): x0 = Dataset(dict(foo=DataArray(np.array([object()])))) x1 = deepcopy(x0) assert x0["foo"].values[0] is not x1["foo"].values[0] + + +def test_clip(ds): + result = ds.clip(min=0.5) + assert result.min(...) >= 0.5 + + result = ds.clip(max=0.5) + assert result.max(...) <= 0.5 + + result = ds.clip(min=0.25, max=0.75) + assert result.min(...) >= 0.25 + assert result.max(...) <= 0.75 + + result = ds.clip(min=ds.mean("y"), max=ds.mean("y")) + assert result.dims == ds.dims + + +class TestNumpyCoercion: + def test_from_numpy(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(ds.as_numpy(), ds) + + @requires_dask + def test_from_dask(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + ds_chunked = ds.chunk(1) + + assert_identical(ds_chunked.as_numpy(), ds.compute()) + + @requires_pint_0_15 + def test_from_pint(self): + from pint import Quantity + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", Quantity(arr, units="Pa"))}, + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_sparse + def test_from_sparse(self): + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + ds = xr.Dataset( + {"a": (["x", "y"], sparr)}, coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.Dataset( + {"a": (["x", "y"], arr)}, coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(ds.as_numpy(), expected) + + @requires_cupy + def test_from_cupy(self): + import cupy as cp + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", cp.array(arr))}, coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + ds = xr.Dataset( + {"a": ("x", Quantity(d, units="Pa"))}, + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = ds.as_numpy() + expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 7886e9fd0d4..ab0d1d9f22c 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -171,6 +171,7 @@ def test_dask_distributed_rasterio_integration_test(loop): @requires_cfgrib +@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") def test_dask_distributed_cfgrib_integration_test(loop): with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): @@ -183,11 +184,7 @@ def test_dask_distributed_cfgrib_integration_test(loop): assert_allclose(actual, expected) -@pytest.mark.skipif( - distributed.__version__ <= "1.19.3", - reason="Need recent distributed version to clean up get", -) -@gen_cluster(client=True, timeout=None) +@gen_cluster(client=True) async def test_async(c, s, a, b): x = create_test_data() assert not dask.is_dask_collection(x) diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 5ad1a6355e6..53ed2c87133 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -90,3 +90,9 @@ def test_maybe_promote(kind, expected): actual = dtypes.maybe_promote(np.dtype(kind)) assert actual[0] == expected[0] assert str(actual[1]) == expected[1] + + +def test_nat_types_membership(): + assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES + assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES + assert np.float64 not in dtypes.NAT_TYPES diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 1342950f3e5..6d49e20909d 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -20,21 +20,22 @@ mean, np_timedelta64_to_float, pd_timedelta_to_float, + push, py_timedelta_to_float, - rolling_window, stack, timedelta_to_numeric, where, ) from xarray.core.pycompat import dask_array_type -from xarray.testing import assert_allclose, assert_equal +from xarray.testing import assert_allclose, assert_equal, assert_identical from . import ( arm_xfail, assert_array_equal, has_dask, has_scipy, - raises_regex, + raise_if_dask_computes, + requires_bottleneck, requires_cftime, requires_dask, ) @@ -72,7 +73,7 @@ def test_first(self): actual = first(self.x, axis=-1, skipna=False) assert_array_equal(expected, actual) - with raises_regex(IndexError, "out of bounds"): + with pytest.raises(IndexError, match=r"out of bounds"): first(self.x, 3) def test_last(self): @@ -93,7 +94,7 @@ def test_last(self): actual = last(self.x, axis=-1, skipna=False) assert_array_equal(expected, actual) - with raises_regex(IndexError, "out of bounds"): + with pytest.raises(IndexError, match=r"out of bounds"): last(self.x, 3) def test_count(self): @@ -284,15 +285,15 @@ def assert_dask_array(da, dask): def test_datetime_mean(dask): # Note: only testing numpy, as dask is broken upstream da = DataArray( - np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8"), + np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8[ns]"), dims=["time"], ) if dask: # Trigger use case where a chunk is full of NaT da = da.chunk({"time": 3}) - expect = DataArray(np.array("2010-01-02", dtype="M8")) - expect_nat = DataArray(np.array("NaT", dtype="M8")) + expect = DataArray(np.array("2010-01-02", dtype="M8[ns]")) + expect_nat = DataArray(np.array("NaT", dtype="M8[ns]")) actual = da.mean() if dask: @@ -374,6 +375,17 @@ def test_cftime_datetime_mean_dask_error(): da.mean() +def test_empty_axis_dtype(): + ds = Dataset() + ds["pos"] = [1, 2, 3] + ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + ds["var"] = "pos", [2, 3, 4] + assert_identical(ds.mean(dim="time")["var"], ds["var"]) + assert_identical(ds.max(dim="time")["var"], ds["var"]) + assert_identical(ds.min(dim="time")["var"], ds["var"]) + assert_identical(ds.sum(dim="time")["var"], ds["var"]) + + @pytest.mark.parametrize("dim_num", [1, 2]) @pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) @pytest.mark.parametrize("dask", [False, True]) @@ -530,32 +542,6 @@ def test_isnull_with_dask(): assert_equal(da.isnull().load(), da.load().isnull()) -@pytest.mark.skipif(not has_dask, reason="This is for dask.") -@pytest.mark.parametrize("axis", [0, -1]) -@pytest.mark.parametrize("window", [3, 8, 11]) -@pytest.mark.parametrize("center", [True, False]) -def test_dask_rolling(axis, window, center): - import dask.array as da - - x = np.array(np.random.randn(100, 40), dtype=float) - dx = da.from_array(x, chunks=[(6, 30, 30, 20, 14), 8]) - - expected = rolling_window( - x, axis=axis, window=window, center=center, fill_value=np.nan - ) - actual = rolling_window( - dx, axis=axis, window=window, center=center, fill_value=np.nan - ) - assert isinstance(actual, da.Array) - assert_array_equal(actual, expected) - assert actual.shape == expected.shape - - # we need to take care of window size if chunk size is small - # window/2 should be smaller than the smallest chunk size. - with pytest.raises(ValueError): - rolling_window(dx, axis=axis, window=100, center=center, fill_value=np.nan) - - @pytest.mark.skipif(not has_dask, reason="This is for dask.") @pytest.mark.parametrize("axis", [0, -1, 1]) @pytest.mark.parametrize("edge_order", [1, 2]) @@ -587,7 +573,10 @@ def test_min_count(dim_num, dtype, dask, func, aggdim, contains_nan, skipna): da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) min_count = 3 - actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, min_count=min_count) assert_allclose(actual, expected) assert_dask_array(actual, dask) @@ -603,7 +592,13 @@ def test_min_count_nd(dtype, dask, func): min_count = 3 dim_num = 3 da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) - actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count) + + # If using Dask, the function call should be lazy. + with raise_if_dask_computes(): + actual = getattr(da, func)( + dim=["x", "y", "z"], skipna=True, min_count=min_count + ) + # Supplying all dims is equivalent to supplying `...` or `None` expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count) @@ -611,6 +606,48 @@ def test_min_count_nd(dtype, dask, func): assert_dask_array(actual, dask) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +@pytest.mark.parametrize("dim", [None, "a", "b"]) +def test_min_count_specific(dask, func, dim): + if dask and not has_dask: + pytest.skip("requires dask") + + # Simple array with four non-NaN values. + da = DataArray(np.ones((6, 6), dtype=np.float64) * np.nan, dims=("a", "b")) + da[0][0] = 2 + da[0][3] = 2 + da[3][0] = 2 + da[3][3] = 2 + if dask: + da = da.chunk({"a": 3, "b": 3}) + + # Expected result if we set min_count to the number of non-NaNs in a + # row/column/the entire array. + if dim: + min_count = 2 + expected = DataArray( + [4.0, np.nan, np.nan] * 2, dims=("a" if dim == "b" else "b",) + ) + else: + min_count = 4 + expected = DataArray(8.0 if func == "sum" else 16.0) + + # Check for that min_count. + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + # With min_count being one higher, should get all NaN. + min_count += 1 + expected *= np.nan + with raise_if_dask_computes(): + actual = getattr(da, func)(dim, skipna=True, min_count=min_count) + assert_dask_array(actual, dask) + assert_allclose(actual, expected) + + @pytest.mark.parametrize("func", ["sum", "prod"]) def test_min_count_dataset(func): da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False) @@ -655,9 +692,12 @@ def test_docs(): have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). min_count : int, default: None - The required number of valid values to perform the operation. - If fewer than min_count non-NA values are present the result will - be NA. New in version 0.10.8: Added with the default being None. + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. New in version 0.10.8: Added with the default being + None. Changed in version 0.17.0: if specified on an integer array + and skipna=True, the result will be a float array. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -831,3 +871,24 @@ def test_least_squares(use_dask, skipna): np.testing.assert_allclose(coeffs, [1.5, 1.25]) np.testing.assert_allclose(residuals, [2.0]) + + +@requires_dask +@requires_bottleneck +def test_push_dask(): + import bottleneck + import dask.array + + array = np.array([np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6]) + expected = bottleneck.push(array, axis=0) + for c in range(1, 11): + with raise_if_dask_computes(): + actual = push(dask.array.from_array(array, chunks=c), axis=0, n=None) + np.testing.assert_equal(actual, expected) + + # some chunks of size-1 with NaN + with raise_if_dask_computes(): + actual = push( + dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=None + ) + np.testing.assert_equal(actual, expected) diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index fa91e5c813d..2d9fa11dda3 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -4,7 +4,7 @@ import xarray as xr -from . import assert_identical, raises_regex +from . import assert_identical @xr.register_dataset_accessor("example_accessor") @@ -84,5 +84,5 @@ class BrokenAccessor: def __init__(self, xarray_obj): raise AttributeError("broken") - with raises_regex(RuntimeError, "error initializing"): + with pytest.raises(RuntimeError, match=r"error initializing"): xr.Dataset().stupid_accessor diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index f2facf5b481..b9ba57f99dc 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -7,9 +7,8 @@ import xarray as xr from xarray.core import formatting -from xarray.core.npcompat import IS_NEP18_ACTIVE -from . import raises_regex +from . import requires_netCDF4 class TestFormatting: @@ -51,7 +50,7 @@ def test_first_n_items(self): expected = array.flat[:n] assert (expected == actual).all() - with raises_regex(ValueError, "at least one item"): + with pytest.raises(ValueError, match=r"at least one item"): formatting.first_n_items(array, 0) def test_last_n_items(self): @@ -61,7 +60,7 @@ def test_last_n_items(self): expected = array.flat[-n:] assert (expected == actual).all() - with raises_regex(ValueError, "at least one item"): + with pytest.raises(ValueError, match=r"at least one item"): formatting.first_n_items(array, 0) def test_last_item(self): @@ -394,8 +393,26 @@ def test_array_repr(self): assert actual == expected + with xr.set_options(display_expand_data=False): + actual = formatting.array_repr(ds[(1, 2)]) + expected = dedent( + """\ + + 0 + Dimensions without coordinates: test""" + ) + + assert actual == expected + + def test_array_repr_variable(self): + var = xr.Variable("x", [0, 1]) + + formatting.array_repr(var) + + with xr.set_options(display_expand_data=False): + formatting.array_repr(var) + -@pytest.mark.skipif(not IS_NEP18_ACTIVE, reason="requires __array_function__") def test_inline_variable_array_repr_custom_repr(): class CustomArray: def __init__(self, value, attr): @@ -465,6 +482,25 @@ def test_large_array_repr_length(): assert len(result) < 50 +@requires_netCDF4 +def test_repr_file_collapsed(tmp_path): + arr = xr.DataArray(np.arange(300), dims="test") + arr.to_netcdf(tmp_path / "test.nc", engine="netcdf4") + + with xr.open_dataarray(tmp_path / "test.nc") as arr, xr.set_options( + display_expand_data=False + ): + actual = formatting.array_repr(arr) + expected = dedent( + """\ + + array([ 0, 1, 2, ..., 297, 298, 299]) + Dimensions without coordinates: test""" + ) + + assert actual == expected + + @pytest.mark.parametrize( "display_max_rows, n_vars, n_attr", [(50, 40, 30), (35, 40, 30), (11, 40, 30), (1, 40, 30)], @@ -496,3 +532,19 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr): len_summary = len(summary) data_vars_print_size = min(display_max_rows, len_summary) assert len_summary == data_vars_print_size + + with xr.set_options( + display_expand_coords=False, + display_expand_data_vars=False, + display_expand_attrs=False, + ): + actual = formatting.dataset_repr(ds) + expected = dedent( + f"""\ + + Dimensions: (time: 2) + Coordinates: (1) + Data variables: ({n_vars}) + Attributes: ({n_attr})""" + ) + assert actual == expected diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 9a210ad6fa3..09c6fa0cf3c 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -1,5 +1,3 @@ -from distutils.version import LooseVersion - import numpy as np import pandas as pd import pytest @@ -57,19 +55,9 @@ def test_short_data_repr_html_non_str_keys(dataset): def test_short_data_repr_html_dask(dask_dataarray): - import dask - - if LooseVersion(dask.__version__) < "2.0.0": - assert not hasattr(dask_dataarray.data, "_repr_html_") - data_repr = fh.short_data_repr_html(dask_dataarray) - assert ( - data_repr - == "dask.array<xarray-<this-array>, shape=(4, 6), dtype=float64, chunksize=(4, 6)>" - ) - else: - assert hasattr(dask_dataarray.data, "_repr_html_") - data_repr = fh.short_data_repr_html(dask_dataarray) - assert data_repr == dask_dataarray.data._repr_html_() + assert hasattr(dask_dataarray.data, "_repr_html_") + data_repr = fh.short_data_repr_html(dask_dataarray) + assert data_repr == dask_dataarray.data._repr_html_() def test_format_dims_no_dims(): @@ -115,6 +103,17 @@ def test_repr_of_dataarray(dataarray): formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") == 2 ) + with xr.set_options(display_expand_data=False): + formatted = fh.array_repr(dataarray) + assert "dim_0" in formatted + # has an expanded data section + assert formatted.count("class='xr-array-in' type='checkbox' checked>") == 0 + # coords and attrs don't have an items so they'll be be disabled and collapsed + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") + == 2 + ) + def test_summary_of_multiindex_coord(multiindex): idx = multiindex.x.variable.to_index_variable() @@ -138,6 +137,20 @@ def test_repr_of_dataset(dataset): assert "<U4" in formatted or ">U4" in formatted assert "<IA>" in formatted + with xr.set_options( + display_expand_coords=False, + display_expand_data_vars=False, + display_expand_attrs=False, + ): + formatted = fh.dataset_repr(dataset) + # coords, attrs, and data_vars are collapsed + assert ( + formatted.count("class='xr-section-summary-in' type='checkbox' checked>") + == 0 + ) + assert "<U4" in formatted or ">U4" in formatted + assert "<IA>" in formatted + def test_repr_text_fallback(dataset): formatted = fh.dataset_repr(dataset) @@ -156,3 +169,29 @@ def test_variable_repr_html(): # Just test that something reasonable was produced. assert html.startswith("") assert "xarray.Variable" in html + + +def test_repr_of_nonstr_dataset(dataset): + ds = dataset.copy() + ds.attrs[1] = "Test value" + ds[2] = ds["tmin"] + formatted = fh.dataset_repr(ds) + assert "
    1 :
    Test value
    " in formatted + assert "
    2" in formatted + + +def test_repr_of_nonstr_dataarray(dataarray): + da = dataarray.rename(dim_0=15) + da.attrs[1] = "value" + formatted = fh.array_repr(da) + assert "
    1 :
    value
    " in formatted + assert "
  • 15: 4
  • " in formatted + + +def test_nonstr_variable_repr_html(): + v = xr.Variable(["time", 10], [[1, 2, 3], [4, 5, 6]], {22: "bar"}) + assert hasattr(v, "_repr_html_") + with xr.set_options(display_style="html"): + html = v._repr_html_().strip() + assert "
    22 :
    bar
    " in html + assert "
  • 10: 3
  • " in html diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1366dd417f7..b2510141d78 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3,9 +3,10 @@ import pytest import xarray as xr +from xarray import DataArray, Dataset, Variable from xarray.core.groupby import _consolidate_slices -from . import assert_allclose, assert_equal, assert_identical, raises_regex +from . import assert_allclose, assert_equal, assert_identical, create_test_data @pytest.fixture @@ -387,8 +388,8 @@ def test_da_groupby_assign_coords(): @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) def test_groupby_repr(obj, dim): actual = repr(obj.groupby(dim)) - expected = "%sGroupBy" % obj.__class__.__name__ - expected += ", grouped over %r " % dim + expected = f"{obj.__class__.__name__}GroupBy" + expected += ", grouped over %r" % dim expected += "\n%r groups with labels " % (len(np.unique(obj[dim]))) if dim == "x": expected += "1, 2, 3, 4, 5." @@ -404,8 +405,8 @@ def test_groupby_repr(obj, dim): @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) def test_groupby_repr_datetime(obj): actual = repr(obj.groupby("t.month")) - expected = "%sGroupBy" % obj.__class__.__name__ - expected += ", grouped over 'month' " + expected = f"{obj.__class__.__name__}GroupBy" + expected += ", grouped over 'month'" expected += "\n%r groups with labels " % (len(np.unique(obj.t.dt.month))) expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12." assert actual == expected @@ -479,34 +480,38 @@ def test_groupby_drops_nans(): def test_groupby_grouping_errors(): dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) - with raises_regex(ValueError, "None of the data falls within bins with edges"): + with pytest.raises( + ValueError, match=r"None of the data falls within bins with edges" + ): dataset.groupby_bins("x", bins=[0.1, 0.2, 0.3]) - with raises_regex(ValueError, "None of the data falls within bins with edges"): + with pytest.raises( + ValueError, match=r"None of the data falls within bins with edges" + ): dataset.to_array().groupby_bins("x", bins=[0.1, 0.2, 0.3]) - with raises_regex(ValueError, "All bin edges are NaN."): + with pytest.raises(ValueError, match=r"All bin edges are NaN."): dataset.groupby_bins("x", bins=[np.nan, np.nan, np.nan]) - with raises_regex(ValueError, "All bin edges are NaN."): + with pytest.raises(ValueError, match=r"All bin edges are NaN."): dataset.to_array().groupby_bins("x", bins=[np.nan, np.nan, np.nan]) - with raises_regex(ValueError, "Failed to group data."): + with pytest.raises(ValueError, match=r"Failed to group data."): dataset.groupby(dataset.foo * np.nan) - with raises_regex(ValueError, "Failed to group data."): + with pytest.raises(ValueError, match=r"Failed to group data."): dataset.to_array().groupby(dataset.foo * np.nan) def test_groupby_reduce_dimension_error(array): grouped = array.groupby("y") - with raises_regex(ValueError, "cannot reduce over dimensions"): + with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): grouped.mean() - with raises_regex(ValueError, "cannot reduce over dimensions"): + with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): grouped.mean("huh") - with raises_regex(ValueError, "cannot reduce over dimensions"): + with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): grouped.mean(("x", "y", "asd")) grouped = array.groupby("y", squeeze=False) @@ -549,7 +554,7 @@ def test_groupby_none_group_name(): assert "group" in mean.dims -def test_groupby_sel(dataset): +def test_groupby_getitem(dataset): assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"]) assert_identical(dataset.sel(z=1), dataset.groupby("z")[1]) @@ -562,4 +567,178 @@ def test_groupby_sel(dataset): assert_identical(expected, actual) +def test_groupby_dataset(): + data = Dataset( + {"z": (["x", "y"], np.random.randn(3, 5))}, + {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)}, + ) + groupby = data.groupby("x") + assert len(groupby) == 3 + expected_groups = {"a": 0, "b": 1, "c": 2} + assert groupby.groups == expected_groups + expected_items = [ + ("a", data.isel(x=0)), + ("b", data.isel(x=1)), + ("c", data.isel(x=2)), + ] + for actual, expected in zip(groupby, expected_items): + assert actual[0] == expected[0] + assert_equal(actual[1], expected[1]) + + def identity(x): + return x + + for k in ["x", "c", "y"]: + actual = data.groupby(k, squeeze=False).map(identity) + assert_equal(data, actual) + + +def test_groupby_dataset_returns_new_type(): + data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) + + actual = data.groupby("x").map(lambda ds: ds["z"]) + expected = data["z"] + assert_identical(expected, actual) + + actual = data["z"].groupby("x").map(lambda x: x.to_dataset()) + expected = data + assert_identical(expected, actual) + + +def test_groupby_dataset_iter(): + data = create_test_data() + for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]): + assert data["dim1"][n] == t + assert_equal(data["var1"][n], sub["var1"]) + assert_equal(data["var2"][n], sub["var2"]) + assert_equal(data["var3"][:, n], sub["var3"]) + + +def test_groupby_dataset_errors(): + data = create_test_data() + with pytest.raises(TypeError, match=r"`group` must be"): + data.groupby(np.arange(10)) + with pytest.raises(ValueError, match=r"length does not match"): + data.groupby(data["dim1"][:3]) + with pytest.raises(TypeError, match=r"`group` must be"): + data.groupby(data.coords["dim1"].to_index()) + + +def test_groupby_dataset_reduce(): + data = Dataset( + { + "xy": (["x", "y"], np.random.randn(3, 4)), + "xonly": ("x", np.random.randn(3)), + "yonly": ("y", np.random.randn(4)), + "letters": ("y", ["a", "a", "b", "b"]), + } + ) + + expected = data.mean("y") + expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3}) + actual = data.groupby("x").mean(...) + assert_allclose(expected, actual) + + actual = data.groupby("x").mean("y") + assert_allclose(expected, actual) + + letters = data["letters"] + expected = Dataset( + { + "xy": data["xy"].groupby(letters).mean(...), + "xonly": (data["xonly"].mean().variable.set_dims({"letters": 2})), + "yonly": data["yonly"].groupby(letters).mean(), + } + ) + actual = data.groupby("letters").mean(...) + assert_allclose(expected, actual) + + +def test_groupby_dataset_math(): + def reorder_dims(x): + return x.transpose("dim1", "dim2", "dim3", "time") + + ds = create_test_data() + ds["dim1"] = ds["dim1"] + for squeeze in [True, False]: + grouped = ds.groupby("dim1", squeeze=squeeze) + + expected = reorder_dims(ds + ds.coords["dim1"]) + actual = grouped + ds.coords["dim1"] + assert_identical(expected, reorder_dims(actual)) + + actual = ds.coords["dim1"] + grouped + assert_identical(expected, reorder_dims(actual)) + + ds2 = 2 * ds + expected = reorder_dims(ds + ds2) + actual = grouped + ds2 + assert_identical(expected, reorder_dims(actual)) + + actual = ds2 + grouped + assert_identical(expected, reorder_dims(actual)) + + grouped = ds.groupby("numbers") + zeros = DataArray([0, 0, 0, 0], [("numbers", range(4))]) + expected = (ds + Variable("dim3", np.zeros(10))).transpose( + "dim3", "dim1", "dim2", "time" + ) + actual = grouped + zeros + assert_equal(expected, actual) + + actual = zeros + grouped + assert_equal(expected, actual) + + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + grouped + ds + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + ds + grouped + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + 1 + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + grouped + with pytest.raises(TypeError, match=r"in-place operations"): + ds += grouped + + ds = Dataset( + { + "x": ("time", np.arange(100)), + "time": pd.date_range("2000-01-01", periods=100), + } + ) + with pytest.raises(ValueError, match=r"incompat.* grouped binary"): + ds + ds.groupby("time.month") + + +def test_groupby_dataset_math_virtual(): + ds = Dataset({"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)}) + grouped = ds.groupby("t.day") + actual = grouped - grouped.mean(...) + expected = Dataset({"x": ("t", [0, 0, 0])}, ds[["t", "t.day"]]) + assert_identical(actual, expected) + + +def test_groupby_dataset_nan(): + # nan should be excluded from groupby + ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])}) + actual = ds.groupby("bar").mean(...) + expected = Dataset({"foo": ("bar", [1.5, 3]), "bar": [1, 2]}) + assert_identical(actual, expected) + + +def test_groupby_dataset_order(): + # groupby should preserve variables order + ds = Dataset() + for vn in ["a", "b", "c"]: + ds[vn] = DataArray(np.arange(10), dims=["t"]) + data_vars_ref = list(ds.data_vars.keys()) + ds = ds.groupby("t").mean(...) + data_vars = list(ds.data_vars.keys()) + assert data_vars == data_vars_ref + # coords are now at the end of the list, so the test below fails + # all_vars = list(ds.variables.keys()) + # all_vars_ref = list(ds.variables.keys()) + # .assertEqual(all_vars, all_vars_ref) + + # TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py new file mode 100644 index 00000000000..c8ba72a253f --- /dev/null +++ b/xarray/tests/test_indexes.py @@ -0,0 +1,196 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe +from xarray.core.variable import IndexVariable + + +def test_asarray_tuplesafe(): + res = _asarray_tuplesafe(("a", 1)) + assert isinstance(res, np.ndarray) + assert res.ndim == 0 + assert res.item() == ("a", 1) + + res = _asarray_tuplesafe([(0,), (1,)]) + assert res.shape == (2,) + assert res[0] == (0,) + assert res[1] == (1,) + + +class TestPandasIndex: + def test_constructor(self): + pd_idx = pd.Index([1, 2, 3]) + index = PandasIndex(pd_idx, "x") + + assert index.index is pd_idx + assert index.dim == "x" + + def test_from_variables(self): + var = xr.Variable( + "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + ) + + index, index_vars = PandasIndex.from_variables({"x": var}) + xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + assert index.dim == "x" + assert index.index.equals(index_vars["x"].to_index()) + + var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=r".*only accepts one variable.*"): + PandasIndex.from_variables({"x": var, "foo": var2}) + + with pytest.raises( + ValueError, match=r".*only accepts a 1-dimensional variable.*" + ): + PandasIndex.from_variables({"foo": var2}) + + def test_from_pandas_index(self): + pd_idx = pd.Index([1, 2, 3], name="foo") + + index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") + + assert index.dim == "x" + assert index.index is pd_idx + assert index.index.name == "foo" + xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) + + # test no name set for pd.Index + pd_idx.name = None + index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") + assert "x" in index_vars + assert index.index is not pd_idx + assert index.index.name == "x" + + def to_pandas_index(self): + pd_idx = pd.Index([1, 2, 3], name="foo") + index = PandasIndex(pd_idx, "x") + assert index.to_pandas_index() is pd_idx + + def test_query(self): + # TODO: add tests that aren't just for edge cases + index = PandasIndex(pd.Index([1, 2, 3]), "x") + with pytest.raises(KeyError, match=r"not all values found"): + index.query({"x": [0]}) + with pytest.raises(KeyError): + index.query({"x": 0}) + with pytest.raises(ValueError, match=r"does not have a MultiIndex"): + index.query({"x": {"one": 0}}) + + def test_query_datetime(self): + index = PandasIndex( + pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" + ) + actual = index.query({"x": "2001-01-01"}) + expected = (1, None) + assert actual == expected + + actual = index.query({"x": index.to_pandas_index().to_numpy()[1]}) + assert actual == expected + + def test_query_unsorted_datetime_index_raises(self): + index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") + with pytest.raises(KeyError): + # pandas will try to convert this into an array indexer. We should + # raise instead, so we can be sure the result of indexing with a + # slice is always a view. + index.query({"x": slice("2001", "2002")}) + + def test_equals(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([1, 2, 3], "x") + assert index1.equals(index2) is True + + def test_union(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([4, 5, 6], "y") + actual = index1.union(index2) + assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6])) + assert actual.dim == "x" + + def test_intersection(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([2, 3, 4], "y") + actual = index1.intersection(index2) + assert actual.index.equals(pd.Index([2, 3])) + assert actual.dim == "x" + + def test_copy(self): + expected = PandasIndex([1, 2, 3], "x") + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + + def test_getitem(self): + pd_idx = pd.Index([1, 2, 3]) + expected = PandasIndex(pd_idx, "x") + actual = expected[1:] + + assert actual.index.equals(pd_idx[1:]) + assert actual.dim == expected.dim + + +class TestPandasMultiIndex: + def test_from_variables(self): + v_level1 = xr.Variable( + "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + ) + v_level2 = xr.Variable( + "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} + ) + + index, index_vars = PandasMultiIndex.from_variables( + {"level1": v_level1, "level2": v_level2} + ) + + expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data]) + assert index.dim == "x" + assert index.index.equals(expected_idx) + + assert list(index_vars) == ["x", "level1", "level2"] + xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) + xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"]) + xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"]) + + var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises( + ValueError, match=r".*only accepts 1-dimensional variables.*" + ): + PandasMultiIndex.from_variables({"var": var}) + + v_level3 = xr.Variable("y", [4, 5, 6]) + with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"): + PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) + + def test_from_pandas_index(self): + pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar")) + + index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x") + + assert index.dim == "x" + assert index.index is pd_idx + assert index.index.names == ("foo", "bar") + xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) + xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) + xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", [4, 5, 6])) + + def test_query(self): + index = PandasMultiIndex( + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" + ) + # test tuples inside slice are considered as scalar indexer values + assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) + + with pytest.raises(KeyError, match=r"not all values found"): + index.query({"x": [0]}) + with pytest.raises(KeyError): + index.query({"x": 0}) + with pytest.raises(ValueError, match=r"cannot provide labels for both.*"): + index.query({"one": 0, "x": "a"}) + with pytest.raises(ValueError, match=r"invalid multi-index level names"): + index.query({"x": {"three": 0}}) + with pytest.raises(IndexError): + index.query({"x": (slice(None), 1, "no_level")}) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 4ef7536e1f2..6e4fd320029 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -7,7 +7,7 @@ from xarray import DataArray, Dataset, Variable from xarray.core import indexing, nputils -from . import IndexerMaker, ReturnItem, assert_array_equal, raises_regex +from . import IndexerMaker, ReturnItem, assert_array_equal B = IndexerMaker(indexing.BasicIndexer) @@ -37,20 +37,9 @@ def test_expanded_indexer(self): j = indexing.expanded_indexer(i, x.ndim) assert_array_equal(x[i], x[j]) assert_array_equal(self.set_to_zero(x, i), self.set_to_zero(x, j)) - with raises_regex(IndexError, "too many indices"): + with pytest.raises(IndexError, match=r"too many indices"): indexing.expanded_indexer(arr[1, 2, 3], 2) - def test_asarray_tuplesafe(self): - res = indexing._asarray_tuplesafe(("a", 1)) - assert isinstance(res, np.ndarray) - assert res.ndim == 0 - assert res.item() == ("a", 1) - - res = indexing._asarray_tuplesafe([(0,), (1,)]) - assert res.shape == (2,) - assert res[0] == (0,) - assert res[1] == (1,) - def test_stacked_multiindex_min_max(self): data = np.random.randn(3, 23, 4) da = DataArray( @@ -66,64 +55,38 @@ def test_stacked_multiindex_min_max(self): assert_array_equal(da2.loc["a", s.max()], data[2, 22, 0]) assert_array_equal(da2.loc["b", s.min()], data[0, 0, 1]) - def test_convert_label_indexer(self): - # TODO: add tests that aren't just for edge cases - index = pd.Index([1, 2, 3]) - with raises_regex(KeyError, "not all values found"): - indexing.convert_label_indexer(index, [0]) - with pytest.raises(KeyError): - indexing.convert_label_indexer(index, 0) - with raises_regex(ValueError, "does not have a MultiIndex"): - indexing.convert_label_indexer(index, {"one": 0}) - - mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) - with raises_regex(KeyError, "not all values found"): - indexing.convert_label_indexer(mindex, [0]) - with pytest.raises(KeyError): - indexing.convert_label_indexer(mindex, 0) - with pytest.raises(ValueError): - indexing.convert_label_indexer(index, {"three": 0}) - with pytest.raises(IndexError): - indexing.convert_label_indexer(mindex, (slice(None), 1, "no_level")) - - def test_convert_label_indexer_datetime(self): - index = pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]) - actual = indexing.convert_label_indexer(index, "2001-01-01") - expected = (1, None) - assert actual == expected - - actual = indexing.convert_label_indexer(index, index.to_numpy()[1]) - assert actual == expected - - def test_convert_unsorted_datetime_index_raises(self): - index = pd.to_datetime(["2001", "2000", "2002"]) - with pytest.raises(KeyError): - # pandas will try to convert this into an array indexer. We should - # raise instead, so we can be sure the result of indexing with a - # slice is always a view. - indexing.convert_label_indexer(index, slice("2001", "2002")) - - def test_get_dim_indexers(self): + def test_group_indexers_by_index(self): mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) - mdata = DataArray(range(4), [("x", mindex)]) - - dim_indexers = indexing.get_dim_indexers(mdata, {"one": "a", "two": 1}) - assert dim_indexers == {"x": {"one": "a", "two": 1}} - - with raises_regex(ValueError, "cannot combine"): - indexing.get_dim_indexers(mdata, {"x": "a", "two": 1}) - - with raises_regex(ValueError, "do not exist"): - indexing.get_dim_indexers(mdata, {"y": "a"}) + data = DataArray( + np.zeros((4, 2, 2)), coords={"x": mindex, "y": [1, 2]}, dims=("x", "y", "z") + ) + data.coords["y2"] = ("y", [2.0, 3.0]) - with raises_regex(ValueError, "do not exist"): - indexing.get_dim_indexers(mdata, {"four": 1}) + indexes, grouped_indexers = indexing.group_indexers_by_index( + data, {"z": 0, "one": "a", "two": 1, "y": 0} + ) + assert indexes == {"x": data.xindexes["x"], "y": data.xindexes["y"]} + assert grouped_indexers == { + "x": {"one": "a", "two": 1}, + "y": {"y": 0}, + None: {"z": 0}, + } + + with pytest.raises(KeyError, match=r"no index found for coordinate y2"): + indexing.group_indexers_by_index(data, {"y2": 2.0}) + with pytest.raises(KeyError, match=r"w is not a valid dimension or coordinate"): + indexing.group_indexers_by_index(data, {"w": "a"}) + with pytest.raises(ValueError, match=r"cannot supply.*"): + indexing.group_indexers_by_index(data, {"z": 1}, method="nearest") def test_remap_label_indexers(self): def test_indexer(data, x, expected_pos, expected_idx=None): - pos, idx = indexing.remap_label_indexers(data, {"x": x}) + pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x}) + idx, _ = new_idx_vars.get("x", (None, None)) + if idx is not None: + idx = idx.to_pandas_index() assert_array_equal(pos.get("x"), expected_pos) - assert_array_equal(idx.get("x"), expected_idx) + assert_array_equal(idx, expected_idx) data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( @@ -224,7 +187,7 @@ def test_lazily_indexed_array(self): original = np.random.rand(10, 20, 30) x = indexing.NumpyIndexingAdapter(original) v = Variable(["i", "j", "k"], original) - lazy = indexing.LazilyOuterIndexedArray(x) + lazy = indexing.LazilyIndexedArray(x) v_lazy = Variable(["i", "j", "k"], lazy) arr = ReturnItem() # test orthogonally applied indexers @@ -244,9 +207,7 @@ def test_lazily_indexed_array(self): ]: assert expected.shape == actual.shape assert_array_equal(expected, actual) - assert isinstance( - actual._data, indexing.LazilyOuterIndexedArray - ) + assert isinstance(actual._data, indexing.LazilyIndexedArray) # make sure actual.key is appropriate type if all( @@ -282,18 +243,18 @@ def test_lazily_indexed_array(self): actual._data, ( indexing.LazilyVectorizedIndexedArray, - indexing.LazilyOuterIndexedArray, + indexing.LazilyIndexedArray, ), ) - assert isinstance(actual._data, indexing.LazilyOuterIndexedArray) + assert isinstance(actual._data, indexing.LazilyIndexedArray) assert isinstance(actual._data.array, indexing.NumpyIndexingAdapter) def test_vectorized_lazily_indexed_array(self): original = np.random.rand(10, 20, 30) x = indexing.NumpyIndexingAdapter(original) v_eager = Variable(["i", "j", "k"], x) - lazy = indexing.LazilyOuterIndexedArray(x) + lazy = indexing.LazilyIndexedArray(x) v_lazy = Variable(["i", "j", "k"], lazy) arr = ReturnItem() @@ -306,7 +267,7 @@ def check_indexing(v_eager, v_lazy, indexers): actual._data, ( indexing.LazilyVectorizedIndexedArray, - indexing.LazilyOuterIndexedArray, + indexing.LazilyIndexedArray, ), ) assert_array_equal(expected, actual) @@ -364,19 +325,19 @@ def test_index_scalar(self): class TestMemoryCachedArray: def test_wrapper(self): - original = indexing.LazilyOuterIndexedArray(np.arange(10)) + original = indexing.LazilyIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) assert_array_equal(wrapped, np.arange(10)) assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) def test_sub_array(self): - original = indexing.LazilyOuterIndexedArray(np.arange(10)) + original = indexing.LazilyIndexedArray(np.arange(10)) wrapped = indexing.MemoryCachedArray(original) child = wrapped[B[:5]] assert isinstance(child, indexing.MemoryCachedArray) assert_array_equal(child, np.arange(5)) assert isinstance(child.array, indexing.NumpyIndexingAdapter) - assert isinstance(wrapped.array, indexing.LazilyOuterIndexedArray) + assert isinstance(wrapped.array, indexing.LazilyIndexedArray) def test_setitem(self): original = np.arange(10) @@ -471,7 +432,7 @@ def test_vectorized_indexer(): check_slice(indexing.VectorizedIndexer) check_array1d(indexing.VectorizedIndexer) check_array2d(indexing.VectorizedIndexer) - with raises_regex(ValueError, "numbers of dimensions"): + with pytest.raises(ValueError, match=r"numbers of dimensions"): indexing.VectorizedIndexer( (np.array(1, dtype=np.int64), np.arange(5, dtype=np.int64)) ) @@ -736,7 +697,7 @@ def test_create_mask_dask(): def test_create_mask_error(): - with raises_regex(TypeError, "unexpected key type"): + with pytest.raises(TypeError, match=r"unexpected key type"): indexing.create_mask((1, 2), (3, 4)) @@ -755,3 +716,16 @@ def test_create_mask_error(): def test_posify_mask_subindexer(indices, expected): actual = indexing._posify_mask_subindexer(indices) np.testing.assert_array_equal(expected, actual) + + +def test_indexing_1d_object_array(): + items = (np.arange(3), np.arange(6)) + arr = DataArray(np.array(items, dtype=object)) + + actual = arr[0] + + expected_data = np.empty((), dtype=object) + expected_data[()] = items[0] + expected = DataArray(expected_data) + + assert [actual.data.item()] == [expected.data.item()] diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 20d5fb12a62..2029e6af05b 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -190,7 +190,7 @@ def func(obj, dim, new_x): "w": xdest["w"], "z2": xdest["z2"], "y": da["y"], - "x": (("z", "w"), xdest), + "x": (("z", "w"), xdest.data), "x2": (("z", "w"), func(da["x2"], "x", xdest)), }, ) @@ -416,15 +416,19 @@ def test_errors(use_dask): @requires_scipy def test_dtype(): - ds = xr.Dataset( - {"var1": ("x", [0, 1, 2]), "var2": ("x", ["a", "b", "c"])}, - coords={"x": [0.1, 0.2, 0.3], "z": ("x", ["a", "b", "c"])}, + data_vars = dict( + a=("time", np.array([1, 1.25, 2])), + b=("time", np.array([True, True, False], dtype=bool)), + c=("time", np.array(["start", "start", "end"], dtype=str)), + ) + time = np.array([0, 0.25, 1], dtype=float) + expected = xr.Dataset(data_vars, coords=dict(time=time)) + actual = xr.Dataset( + {k: (dim, arr[[0, -1]]) for k, (dim, arr) in data_vars.items()}, + coords=dict(time=time[[0, -1]]), ) - actual = ds.interp(x=[0.15, 0.25]) - assert "var1" in actual - assert "var2" not in actual - # object array should be dropped - assert "z" not in actual.coords + actual = actual.interp(time=time, method="linear") + assert_identical(expected, actual) @requires_scipy @@ -495,7 +499,7 @@ def test_dataset(): @pytest.mark.parametrize("case", [0, 3]) def test_interpolate_dimorder(case): - """ Make sure the resultant dimension order is consistent with .sel() """ + """Make sure the resultant dimension order is consistent with .sel()""" if not has_scipy: pytest.skip("scipy is not installed.") @@ -723,6 +727,7 @@ def test_datetime_interp_noerror(): @requires_cftime +@requires_scipy def test_3641(): times = xr.cftime_range("0001", periods=3, freq="500Y") da = xr.DataArray(range(3), dims=["time"], coords=[times]) @@ -825,6 +830,7 @@ def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked): @requires_scipy @requires_dask @pytest.mark.parametrize("method", ["linear", "nearest"]) +@pytest.mark.filterwarnings("ignore:Increasing number of chunks") def test_interpolate_chunk_advanced(method): """Interpolate nd array with an nd indexer sharing coordinates.""" # Create original array @@ -866,3 +872,38 @@ def test_interpolate_chunk_advanced(method): z = z.chunk(3) actual = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method) assert_identical(actual, expected) + + +@requires_scipy +def test_interp1d_bounds_error(): + """Ensure exception on bounds error is raised if requested""" + da = xr.DataArray( + np.sin(0.3 * np.arange(4)), + [("time", np.arange(4))], + ) + + with pytest.raises(ValueError): + da.interp(time=3.5, kwargs=dict(bounds_error=True)) + + # default is to fill with nans, so this should pass + da.interp(time=3.5) + + +@requires_scipy +@pytest.mark.parametrize( + "x, expect_same_attrs", + [ + (2.5, True), + (np.array([2.5, 5]), True), + (("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False), + ], +) +def test_coord_attrs(x, expect_same_attrs): + base_attrs = dict(foo="bar") + ds = xr.Dataset( + data_vars=dict(a=2 * np.arange(5)), + coords={"x": ("x", np.arange(5), base_attrs)}, + ) + + has_same_attrs = ds.interp(x=x).x.attrs == base_attrs + assert expect_same_attrs == has_same_attrs diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 34b138e1f6a..555a29b1952 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -6,7 +6,6 @@ from xarray.core.merge import MergeError from xarray.testing import assert_equal, assert_identical -from . import raises_regex from .test_dataset import create_test_data @@ -30,13 +29,14 @@ def test_broadcast_dimension_size(self): class TestMergeFunction: def test_merge_arrays(self): - data = create_test_data() + data = create_test_data(add_attrs=False) + actual = xr.merge([data.var1, data.var2]) expected = data[["var1", "var2"]] assert_identical(actual, expected) def test_merge_datasets(self): - data = create_test_data() + data = create_test_data(add_attrs=False) actual = xr.merge([data[["var1"]], data[["var2"]]]) expected = data[["var1", "var2"]] @@ -47,20 +47,23 @@ def test_merge_datasets(self): def test_merge_dataarray_unnamed(self): data = xr.DataArray([1, 2], dims="x") - with raises_regex(ValueError, "without providing an explicit name"): + with pytest.raises(ValueError, match=r"without providing an explicit name"): xr.merge([data]) def test_merge_arrays_attrs_default(self): var1_attrs = {"a": 1, "b": 2} var2_attrs = {"a": 1, "c": 3} - expected_attrs = {} + expected_attrs = {"a": 1, "b": 2} + + data = create_test_data(add_attrs=False) + expected = data[["var1", "var2"]].copy() + expected.var1.attrs = var1_attrs + expected.var2.attrs = var2_attrs + expected.attrs = expected_attrs - data = create_test_data() data.var1.attrs = var1_attrs data.var2.attrs = var2_attrs actual = xr.merge([data.var1, data.var2]) - expected = data[["var1", "var2"]] - expected.attrs = expected_attrs assert_identical(actual, expected) @pytest.mark.parametrize( @@ -92,21 +95,109 @@ def test_merge_arrays_attrs_default(self): {"a": 1, "b": 2}, False, ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": np.array([2]), "c": np.array([3])}, + {"b": 1, "c": np.array([3]), "d": 4}, + {"a": 1, "c": np.array([3]), "d": 4}, + False, + ), + ( + lambda attrs, context: attrs[1], + {"a": 1, "b": 2, "c": 3}, + {"a": 4, "b": 3, "c": 1}, + {"a": 4, "b": 3, "c": 1}, + False, + ), ], ) def test_merge_arrays_attrs( self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception ): - data = create_test_data() - data.var1.attrs = var1_attrs - data.var2.attrs = var2_attrs + data1 = xr.Dataset(attrs=var1_attrs) + data2 = xr.Dataset(attrs=var2_attrs) if expect_exception: - with raises_regex(MergeError, "combine_attrs"): - actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) + with pytest.raises(MergeError, match="combine_attrs"): + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) else: - actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) - expected = data[["var1", "var2"]] - expected.attrs = expected_attrs + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + expected = xr.Dataset(attrs=expected_attrs) + + assert_identical(actual, expected) + + @pytest.mark.parametrize( + "combine_attrs, attrs1, attrs2, expected_attrs, expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ( + "drop_conflicts", + {"a": 1, "b": 2, "c": 3}, + {"b": 1, "c": 3, "d": 4}, + {"a": 1, "c": 3, "d": 4}, + False, + ), + ( + lambda attrs, context: attrs[1], + {"a": 1, "b": 2, "c": 3}, + {"a": 4, "b": 3, "c": 1}, + {"a": 4, "b": 3, "c": 1}, + False, + ), + ], + ) + def test_merge_arrays_attrs_variables( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_exception + ): + """check that combine_attrs is used on data variables and coords""" + data1 = xr.Dataset( + {"var1": ("dim1", [], attrs1)}, coords={"dim1": ("dim1", [], attrs1)} + ) + data2 = xr.Dataset( + {"var1": ("dim1", [], attrs2)}, coords={"dim1": ("dim1", [], attrs2)} + ) + + if expect_exception: + with pytest.raises(MergeError, match="combine_attrs"): + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + else: + actual = xr.merge([data1, data2], combine_attrs=combine_attrs) + expected = xr.Dataset( + {"var1": ("dim1", [], expected_attrs)}, + coords={"dim1": ("dim1", [], expected_attrs)}, + ) + assert_identical(actual, expected) def test_merge_attrs_override_copy(self): @@ -116,6 +207,23 @@ def test_merge_attrs_override_copy(self): ds3.attrs["x"] = 2 assert ds1.x == 0 + def test_merge_attrs_drop_conflicts(self): + ds1 = xr.Dataset(attrs={"a": 0, "b": 0, "c": 0}) + ds2 = xr.Dataset(attrs={"b": 0, "c": 1, "d": 0}) + ds3 = xr.Dataset(attrs={"a": 0, "b": 1, "c": 0, "e": 0}) + + actual = xr.merge([ds1, ds2, ds3], combine_attrs="drop_conflicts") + expected = xr.Dataset(attrs={"a": 0, "d": 0, "e": 0}) + assert_identical(actual, expected) + + def test_merge_attrs_no_conflicts_compat_minimal(self): + """make sure compat="minimal" does not silence errors""" + ds1 = xr.Dataset({"a": ("x", [], {"a": 0})}) + ds2 = xr.Dataset({"a": ("x", [], {"a": 1})}) + + with pytest.raises(xr.MergeError, match="combine_attrs"): + xr.merge([ds1, ds2], combine_attrs="no_conflicts", compat="minimal") + def test_merge_dicts_simple(self): actual = xr.merge([{"foo": 0}, {"bar": "one"}, {"baz": 3.5}]) expected = xr.Dataset({"foo": 0, "bar": "one", "baz": 3.5}) @@ -134,16 +242,16 @@ def test_merge_error(self): def test_merge_alignment_error(self): ds = xr.Dataset(coords={"x": [1, 2]}) other = xr.Dataset(coords={"x": [2, 3]}) - with raises_regex(ValueError, "indexes .* not equal"): + with pytest.raises(ValueError, match=r"indexes .* not equal"): xr.merge([ds, other], join="exact") def test_merge_wrong_input_error(self): - with raises_regex(TypeError, "objects must be an iterable"): + with pytest.raises(TypeError, match=r"objects must be an iterable"): xr.merge([1]) ds = xr.Dataset(coords={"x": [1, 2]}) - with raises_regex(TypeError, "objects must be an iterable"): + with pytest.raises(TypeError, match=r"objects must be an iterable"): xr.merge({"a": ds}) - with raises_regex(TypeError, "objects must be an iterable"): + with pytest.raises(TypeError, match=r"objects must be an iterable"): xr.merge([ds, 1]) def test_merge_no_conflicts_single_var(self): @@ -168,7 +276,7 @@ def test_merge_no_conflicts_single_var(self): xr.merge([ds1, ds3], compat="no_conflicts") def test_merge_no_conflicts_multi_var(self): - data = create_test_data() + data = create_test_data(add_attrs=False) data1 = data.copy(deep=True) data2 = data.copy(deep=True) @@ -187,7 +295,7 @@ def test_merge_no_conflicts_multi_var(self): def test_merge_no_conflicts_preserve_attrs(self): data = xr.Dataset({"x": ([], 0, {"foo": "bar"})}) - actual = xr.merge([data, data]) + actual = xr.merge([data, data], combine_attrs="no_conflicts") assert_identical(data, actual) def test_merge_no_conflicts_broadcast(self): @@ -222,9 +330,9 @@ def test_merge(self): with pytest.raises(ValueError): ds1.merge(ds2.rename({"var3": "var1"})) - with raises_regex(ValueError, "should be coordinates or not"): + with pytest.raises(ValueError, match=r"should be coordinates or not"): data.reset_coords().merge(data) - with raises_regex(ValueError, "should be coordinates or not"): + with pytest.raises(ValueError, match=r"should be coordinates or not"): data.merge(data.reset_coords()) def test_merge_broadcast_equals(self): @@ -254,14 +362,14 @@ def test_merge_compat(self): ds2 = xr.Dataset({"x": [0, 0]}) for compat in ["equals", "identical"]: - with raises_regex(ValueError, "should be coordinates or not"): + with pytest.raises(ValueError, match=r"should be coordinates or not"): ds1.merge(ds2, compat=compat) ds2 = xr.Dataset({"x": ((), 0, {"foo": "bar"})}) with pytest.raises(xr.MergeError): ds1.merge(ds2, compat="identical") - with raises_regex(ValueError, "compat=.* invalid"): + with pytest.raises(ValueError, match=r"compat=.* invalid"): ds1.merge(ds2, compat="foobar") assert ds1.identical(ds1.merge(ds2, compat="override")) @@ -333,3 +441,34 @@ def test_merge_dataarray(self): da = xr.DataArray(data=1, name="b") assert_identical(ds.merge(da), xr.merge([ds, da])) + + @pytest.mark.parametrize( + ["combine_attrs", "attrs1", "attrs2", "expected_attrs", "expect_error"], + # don't need to test thoroughly + ( + ("drop", {"a": 0, "b": 1, "c": 2}, {"a": 1, "b": 2, "c": 3}, {}, False), + ( + "drop_conflicts", + {"a": 0, "b": 1, "c": 2}, + {"b": 2, "c": 2, "d": 3}, + {"a": 0, "c": 2, "d": 3}, + False, + ), + ("override", {"a": 0, "b": 1}, {"a": 1, "b": 2}, {"a": 0, "b": 1}, False), + ("no_conflicts", {"a": 0, "b": 1}, {"a": 0, "b": 2}, None, True), + ("identical", {"a": 0, "b": 1}, {"a": 0, "b": 2}, None, True), + ), + ) + def test_merge_combine_attrs( + self, combine_attrs, attrs1, attrs2, expected_attrs, expect_error + ): + ds1 = xr.Dataset(attrs=attrs1) + ds2 = xr.Dataset(attrs=attrs2) + + if expect_error: + with pytest.raises(xr.MergeError): + ds1.merge(ds2, combine_attrs=combine_attrs) + else: + actual = ds1.merge(ds2, combine_attrs=combine_attrs) + expected = xr.Dataset(attrs=expected_attrs) + assert_identical(actual, expected) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 2ab3508b667..1ebcd9ac6f7 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -14,16 +14,16 @@ ) from xarray.core.pycompat import dask_array_type from xarray.tests import ( + _CFTIME_CALENDARS, assert_allclose, assert_array_equal, assert_equal, - raises_regex, + raise_if_dask_computes, requires_bottleneck, requires_cftime, requires_dask, requires_scipy, ) -from xarray.tests.test_cftime_offsets import _CFTIME_CALENDARS @pytest.fixture @@ -174,26 +174,26 @@ def test_interpolate_pd_compat_polynomial(): def test_interpolate_unsorted_index_raises(): vals = np.array([1, 2, 3], dtype=np.float64) expected = xr.DataArray(vals, dims="x", coords={"x": [2, 1, 3]}) - with raises_regex(ValueError, "Index 'x' must be monotonically increasing"): + with pytest.raises(ValueError, match=r"Index 'x' must be monotonically increasing"): expected.interpolate_na(dim="x", method="index") def test_interpolate_no_dim_raises(): da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") - with raises_regex(NotImplementedError, "dim is a required argument"): + with pytest.raises(NotImplementedError, match=r"dim is a required argument"): da.interpolate_na(method="linear") def test_interpolate_invalid_interpolator_raises(): da = xr.DataArray(np.array([1, 2, np.nan, 5], dtype=np.float64), dims="x") - with raises_regex(ValueError, "not a valid"): + with pytest.raises(ValueError, match=r"not a valid"): da.interpolate_na(dim="x", method="foo") def test_interpolate_duplicate_values_raises(): data = np.random.randn(2, 3) da = xr.DataArray(data, coords=[("x", ["a", "a"]), ("y", [0, 1, 2])]) - with raises_regex(ValueError, "Index 'x' has duplicate values"): + with pytest.raises(ValueError, match=r"Index 'x' has duplicate values"): da.interpolate_na(dim="x", method="foo") @@ -202,7 +202,7 @@ def test_interpolate_multiindex_raises(): data[1, 1] = np.nan da = xr.DataArray(data, coords=[("x", ["a", "b"]), ("y", [0, 1, 2])]) das = da.stack(z=("x", "y")) - with raises_regex(TypeError, "Index 'z' must be castable to float64"): + with pytest.raises(TypeError, match=r"Index 'z' must be castable to float64"): das.interpolate_na(dim="z") @@ -215,7 +215,7 @@ def test_interpolate_2d_coord_raises(): data = np.random.randn(2, 3) data[1, 1] = np.nan da = xr.DataArray(data, dims=("a", "b"), coords=coords) - with raises_regex(ValueError, "interpolation must be 1D"): + with pytest.raises(ValueError, match=r"interpolation must be 1D"): da.interpolate_na(dim="a", use_coordinate="x") @@ -366,7 +366,7 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim(): da, _ = make_interpolate_example_data((40, 40), 0.5) da = da.chunk({"time": 5}) # this checks for ValueError in dask.array.apply_gufunc - with raises_regex(ValueError, "consists of multiple chunks"): + with pytest.raises(ValueError, match=r"consists of multiple chunks"): da.interpolate_na("time") @@ -392,39 +392,73 @@ def test_ffill(): assert_equal(actual, expected) -@requires_bottleneck +def test_ffill_use_bottleneck(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + da.ffill("x") + + @requires_dask -def test_ffill_dask(): - da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({"x": 5}) - actual = da.ffill("time") - expected = da.load().ffill("time") - assert isinstance(actual.data, dask_array_type) - assert_equal(actual, expected) +def test_ffill_use_bottleneck_dask(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk({"x": 1}) + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + da.ffill("x") - # with limit - da = da.chunk({"x": 5}) - actual = da.ffill("time", limit=3) - expected = da.load().ffill("time", limit=3) - assert isinstance(actual.data, dask_array_type) - assert_equal(actual, expected) + +def test_bfill_use_bottleneck(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + da.bfill("x") + + +@requires_dask +def test_bfill_use_bottleneck_dask(): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk({"x": 1}) + with xr.set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + da.bfill("x") @requires_bottleneck @requires_dask -def test_bfill_dask(): +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +def test_ffill_bfill_dask(method): da, _ = make_interpolate_example_data((40, 40), 0.5) da = da.chunk({"x": 5}) - actual = da.bfill("time") - expected = da.load().bfill("time") - assert isinstance(actual.data, dask_array_type) + + dask_method = getattr(da, method) + numpy_method = getattr(da.compute(), method) + # unchunked axis + with raise_if_dask_computes(): + actual = dask_method("time") + expected = numpy_method("time") + assert_equal(actual, expected) + + # chunked axis + with raise_if_dask_computes(): + actual = dask_method("x") + expected = numpy_method("x") assert_equal(actual, expected) # with limit - da = da.chunk({"x": 5}) - actual = da.bfill("time", limit=3) - expected = da.load().bfill("time", limit=3) - assert isinstance(actual.data, dask_array_type) + with raise_if_dask_computes(): + actual = dask_method("time", limit=3) + expected = numpy_method("time", limit=3) + assert_equal(actual, expected) + + # limit < axis size + with pytest.raises(NotImplementedError): + actual = dask_method("x", limit=2) + + # limit > axis size + with raise_if_dask_computes(): + actual = dask_method("x", limit=41) + expected = numpy_method("x", limit=41) assert_equal(actual, expected) @@ -540,6 +574,7 @@ def test_get_clean_interp_index_dt(cf_da, calendar, freq): np.testing.assert_array_equal(gi, si) +@requires_cftime def test_get_clean_interp_index_potential_overflow(): da = xr.DataArray( [0, 1, 2], @@ -570,27 +605,30 @@ def da_time(): def test_interpolate_na_max_gap_errors(da_time): - with raises_regex( - NotImplementedError, "max_gap not implemented for unlabeled coordinates" + with pytest.raises( + NotImplementedError, match=r"max_gap not implemented for unlabeled coordinates" ): da_time.interpolate_na("t", max_gap=1) - with raises_regex(ValueError, "max_gap must be a scalar."): + with pytest.raises(ValueError, match=r"max_gap must be a scalar."): da_time.interpolate_na("t", max_gap=(1,)) da_time["t"] = pd.date_range("2001-01-01", freq="H", periods=11) - with raises_regex(TypeError, "Expected value of type str"): + with pytest.raises(TypeError, match=r"Expected value of type str"): da_time.interpolate_na("t", max_gap=1) - with raises_regex(TypeError, "Expected integer or floating point"): + with pytest.raises(TypeError, match=r"Expected integer or floating point"): da_time.interpolate_na("t", max_gap="1H", use_coordinate=False) - with raises_regex(ValueError, "Could not convert 'huh' to timedelta64"): + with pytest.raises(ValueError, match=r"Could not convert 'huh' to timedelta64"): da_time.interpolate_na("t", max_gap="huh") @requires_bottleneck -@pytest.mark.parametrize("time_range_func", [pd.date_range, xr.cftime_range]) +@pytest.mark.parametrize( + "time_range_func", + [pd.date_range, pytest.param(xr.cftime_range, marks=requires_cftime)], +) @pytest.mark.parametrize("transform", [lambda x: x, lambda x: x.to_dataset(name="a")]) @pytest.mark.parametrize( "max_gap", ["3H", np.timedelta64(3, "h"), pd.to_timedelta("3H")] diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index ccb825dc7e9..3c9c92ae2ba 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,8 +1,7 @@ import numpy as np -import pytest from numpy.testing import assert_array_equal -from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window +from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous def test_is_contiguous(): @@ -29,38 +28,3 @@ def test_vindex(): vindex[[0, 1], [0, 1], :] = vindex[[0, 1], [0, 1], :] vindex[[0, 1], :, [0, 1]] = vindex[[0, 1], :, [0, 1]] vindex[:, [0, 1], [0, 1]] = vindex[:, [0, 1], [0, 1]] - - -def test_rolling(): - x = np.array([1, 2, 3, 4], dtype=float) - - actual = rolling_window(x, axis=-1, window=3, center=True, fill_value=np.nan) - expected = np.array( - [[np.nan, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, np.nan]], dtype=float - ) - assert_array_equal(actual, expected) - - actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) - expected = np.array([[0, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], dtype=float) - assert_array_equal(actual, expected) - - x = np.stack([x, x * 1.1]) - actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) - expected = np.stack([expected, expected * 1.1], axis=0) - assert_array_equal(actual, expected) - - -@pytest.mark.parametrize("center", [[True, True], [False, False]]) -@pytest.mark.parametrize("axis", [(0, 1), (1, 2), (2, 0)]) -def test_nd_rolling(center, axis): - x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float) - window = [3, 3] - actual = rolling_window( - x, axis=axis, window=window, center=center, fill_value=np.nan - ) - expected = x - for ax, win, cent in zip(axis, window, center): - expected = rolling_window( - expected, axis=ax, window=win, center=cent, fill_value=np.nan - ) - assert_array_equal(actual, expected) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 363a05310ad..2fa2829b12a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,6 +2,7 @@ import inspect from copy import copy from datetime import datetime +from typing import Any, Dict, Union import numpy as np import pandas as pd @@ -16,6 +17,7 @@ _build_discrete_cmap, _color_palette, _determine_cmap_params, + _maybe_gca, get_axis, label_from_attrs, ) @@ -24,10 +26,10 @@ assert_array_equal, assert_equal, has_nc_time_axis, - raises_regex, requires_cartopy, requires_cftime, requires_matplotlib, + requires_matplotlib_3_3_0, requires_nc_time_axis, requires_seaborn, ) @@ -36,13 +38,14 @@ try: import matplotlib as mpl import matplotlib.pyplot as plt + import mpl_toolkits # type: ignore except ImportError: pass try: - import cartopy as ctpy # type: ignore + import cartopy except ImportError: - ctpy = None + pass @contextlib.contextmanager @@ -132,8 +135,8 @@ def setup(self): # Remove all matplotlib figures plt.close("all") - def pass_in_axis(self, plotmethod): - fig, axes = plt.subplots(ncols=2) + def pass_in_axis(self, plotmethod, subplot_kw=None): + fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw) plotmethod(ax=axes[0]) assert axes[0].has_data() @@ -185,10 +188,10 @@ def test_label_from_attrs(self): def test1d(self): self.darray[:, 0, 0].plot() - with raises_regex(ValueError, "x must be one of None, 'dim_0'"): + with pytest.raises(ValueError, match=r"x must be one of None, 'dim_0'"): self.darray[:, 0, 0].plot(x="dim_1") - with raises_regex(TypeError, "complex128"): + with pytest.raises(TypeError, match=r"complex128"): (self.darray[:, 0, 0] + 1j).plot() def test_1d_bool(self): @@ -204,14 +207,14 @@ def test_1d_x_y_kw(self): for aa, (x, y) in enumerate(xy): da.plot(x=x, y=y, ax=ax.flat[aa]) - with raises_regex(ValueError, "Cannot specify both"): + with pytest.raises(ValueError, match=r"Cannot specify both"): da.plot(x="z", y="z") error_msg = "must be one of None, 'z'" - with raises_regex(ValueError, f"x {error_msg}"): + with pytest.raises(ValueError, match=rf"x {error_msg}"): da.plot(x="f") - with raises_regex(ValueError, f"y {error_msg}"): + with pytest.raises(ValueError, match=rf"y {error_msg}"): da.plot(y="f") def test_multiindex_level_as_coord(self): @@ -267,8 +270,17 @@ def test_line_plot_along_1d_coord(self): line = da.plot(y="time", hue="x")[0] assert_array_equal(line.get_ydata(), da.coords["time"].values) + def test_line_plot_wrong_hue(self): + da = xr.DataArray( + data=np.array([[0, 1], [5, 9]]), + dims=["x", "t"], + ) + + with pytest.raises(ValueError, match="hue must be one of"): + da.plot(x="t", hue="wrong_coord") + def test_2d_line(self): - with raises_regex(ValueError, "hue"): + with pytest.raises(ValueError, match=r"hue"): self.darray[:, :, 0].plot.line() self.darray[:, :, 0].plot.line(hue="dim_1") @@ -277,7 +289,7 @@ def test_2d_line(self): self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1") self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1") - with raises_regex(ValueError, "Cannot"): + with pytest.raises(ValueError, match=r"Cannot"): self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1") def test_2d_line_accepts_legend_kw(self): @@ -360,6 +372,34 @@ def test2d_1d_2d_coordinates_contourf(self): a.plot.contourf(x="time", y="depth") a.plot.contourf(x="depth", y="time") + def test2d_1d_2d_coordinates_pcolormesh(self): + # Test with equal coordinates to catch bug from #5097 + sz = 10 + y2d, x2d = np.meshgrid(np.arange(sz), np.arange(sz)) + a = DataArray( + easy_array((sz, sz)), + dims=["x", "y"], + coords={"x2d": (["x", "y"], x2d), "y2d": (["x", "y"], y2d)}, + ) + + for x, y in [ + ("x", "y"), + ("y", "x"), + ("x2d", "y"), + ("y", "x2d"), + ("x", "y2d"), + ("y2d", "x"), + ("x2d", "y2d"), + ("y2d", "x2d"), + ]: + p = a.plot.pcolormesh(x=x, y=y) + v = p.get_paths()[0].vertices + + # Check all vertices are different, except last vertex which should be the + # same as the first + _, unique_counts = np.unique(v[:-1], axis=0, return_counts=True) + assert np.all(unique_counts == 1) + def test_contourf_cmap_set(self): a = DataArray(easy_array((4, 4)), dims=["z", "time"]) @@ -454,6 +494,39 @@ def test__infer_interval_breaks(self): with pytest.raises(ValueError): _infer_interval_breaks(np.array([0, 2, 1]), check_monotonic=True) + def test__infer_interval_breaks_logscale(self): + """ + Check if interval breaks are defined in the logspace if scale="log" + """ + # Check for 1d arrays + x = np.logspace(-4, 3, 8) + expected_interval_breaks = 10 ** np.linspace(-4.5, 3.5, 9) + np.testing.assert_allclose( + _infer_interval_breaks(x, scale="log"), expected_interval_breaks + ) + + # Check for 2d arrays + x = np.logspace(-4, 3, 8) + y = np.linspace(-5, 5, 11) + x, y = np.meshgrid(x, y) + expected_interval_breaks = np.vstack([10 ** np.linspace(-4.5, 3.5, 9)] * 12) + x = _infer_interval_breaks(x, axis=1, scale="log") + x = _infer_interval_breaks(x, axis=0, scale="log") + np.testing.assert_allclose(x, expected_interval_breaks) + + def test__infer_interval_breaks_logscale_invalid_coords(self): + """ + Check error is raised when passing non-positive coordinates with logscale + """ + # Check if error is raised after a zero value in the array + x = np.linspace(0, 5, 6) + with pytest.raises(ValueError): + _infer_interval_breaks(x, scale="log") + # Check if error is raised after nagative values in the array + x = np.linspace(-5, 5, 11) + with pytest.raises(ValueError): + _infer_interval_breaks(x, scale="log") + def test_geo_data(self): # Regression test for gh2250 # Realistic coordinates taken from the example dataset @@ -509,10 +582,10 @@ def test_convenient_facetgrid(self): for ax in g.axes.flat: assert ax.has_data() - with raises_regex(ValueError, "[Ff]acet"): + with pytest.raises(ValueError, match=r"[Ff]acet"): d.plot(x="x", y="y", col="z", ax=plt.gca()) - with raises_regex(ValueError, "[Ff]acet"): + with pytest.raises(ValueError, match=r"[Ff]acet"): d[0].plot(x="x", y="y", col="z", ax=plt.gca()) @pytest.mark.slow @@ -546,16 +619,16 @@ def test_plot_size(self): self.darray.plot(size=5, aspect=2) assert tuple(plt.gcf().get_size_inches()) == (10, 5) - with raises_regex(ValueError, "cannot provide both"): + with pytest.raises(ValueError, match=r"cannot provide both"): self.darray.plot(ax=plt.gca(), figsize=(3, 4)) - with raises_regex(ValueError, "cannot provide both"): + with pytest.raises(ValueError, match=r"cannot provide both"): self.darray.plot(size=5, figsize=(3, 4)) - with raises_regex(ValueError, "cannot provide both"): + with pytest.raises(ValueError, match=r"cannot provide both"): self.darray.plot(size=5, ax=plt.gca()) - with raises_regex(ValueError, "cannot provide `aspect`"): + with pytest.raises(ValueError, match=r"cannot provide `aspect`"): self.darray.plot(aspect=1) @pytest.mark.slow @@ -569,7 +642,7 @@ def test_convenient_facetgrid_4d(self): for ax in g.axes.flat: assert ax.has_data() - with raises_regex(ValueError, "[Ff]acet"): + with pytest.raises(ValueError, match=r"[Ff]acet"): d.plot(x="x", y="y", col="columns", ax=plt.gca()) def test_coord_with_interval(self): @@ -644,10 +717,9 @@ def test_format_string(self): def test_can_pass_in_axis(self): self.pass_in_axis(self.darray.plot.line) - def test_nonnumeric_index_raises_typeerror(self): + def test_nonnumeric_index(self): a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter") - with raises_regex(TypeError, r"[Pp]lot"): - a.plot.line() + a.plot.line() def test_primitive_returned(self): p = self.darray.plot.line() @@ -1070,6 +1142,9 @@ class Common2dMixin: Should have the same name as the method. """ + # Needs to be overridden in TestSurface for facet grid plots + subplot_kws: Union[Dict[Any, Any], None] = None + @pytest.fixture(autouse=True) def setUp(self): da = DataArray( @@ -1102,26 +1177,30 @@ def test_label_names(self): assert "y_long_name [y_units]" == plt.gca().get_ylabel() def test_1d_raises_valueerror(self): - with raises_regex(ValueError, r"DataArray must be 2d"): + with pytest.raises(ValueError, match=r"DataArray must be 2d"): self.plotfunc(self.darray[0, :]) def test_bool(self): xr.ones_like(self.darray, dtype=bool).plot() def test_complex_raises_typeerror(self): - with raises_regex(TypeError, "complex128"): + with pytest.raises(TypeError, match=r"complex128"): (self.darray + 1j).plot() def test_3d_raises_valueerror(self): a = DataArray(easy_array((2, 3, 4))) if self.plotfunc.__name__ == "imshow": pytest.skip() - with raises_regex(ValueError, r"DataArray must be 2d"): + with pytest.raises(ValueError, match=r"DataArray must be 2d"): self.plotfunc(a) - def test_nonnumeric_index_raises_typeerror(self): + def test_nonnumeric_index(self): a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]]) - with raises_regex(TypeError, r"[Pp]lot"): + if self.plotfunc.__name__ == "surface": + # ax.plot_surface errors with nonnumerics: + with pytest.raises(Exception): + self.plotfunc(a) + else: self.plotfunc(a) def test_multiindex_raises_typeerror(self): @@ -1131,7 +1210,7 @@ def test_multiindex_raises_typeerror(self): coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])), ) a = a.set_index(y=("a", "b")) - with raises_regex(TypeError, r"[Pp]lot"): + with pytest.raises(TypeError, match=r"[Pp]lot"): self.plotfunc(a) def test_can_pass_in_axis(self): @@ -1243,15 +1322,15 @@ def test_positional_coord_string(self): def test_bad_x_string_exception(self): - with raises_regex(ValueError, "x and y cannot be equal."): + with pytest.raises(ValueError, match=r"x and y cannot be equal."): self.plotmethod(x="y", y="y") error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'" - with raises_regex(ValueError, f"x {error_msg}"): + with pytest.raises(ValueError, match=rf"x {error_msg}"): self.plotmethod("not_a_real_dim", "y") - with raises_regex(ValueError, f"x {error_msg}"): + with pytest.raises(ValueError, match=rf"x {error_msg}"): self.plotmethod(x="not_a_real_dim") - with raises_regex(ValueError, f"y {error_msg}"): + with pytest.raises(ValueError, match=rf"y {error_msg}"): self.plotmethod(y="not_a_real_dim") self.darray.coords["z"] = 100 @@ -1301,10 +1380,10 @@ def test_multiindex_level_as_coord(self): assert x == ax.get_xlabel() assert y == ax.get_ylabel() - with raises_regex(ValueError, "levels of the same MultiIndex"): + with pytest.raises(ValueError, match=r"levels of the same MultiIndex"): self.plotfunc(da, x="a", y="b") - with raises_regex(ValueError, "y must be one of None, 'a', 'b', 'x'"): + with pytest.raises(ValueError, match=r"y must be one of None, 'a', 'b', 'x'"): self.plotfunc(da, x="a", y="y") def test_default_title(self): @@ -1385,7 +1464,7 @@ def test_colorbar_kwargs(self): def test_verbose_facetgrid(self): a = easy_array((10, 15, 3)) d = DataArray(a, dims=["y", "x", "z"]) - g = xplt.FacetGrid(d, col="z") + g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws) g.map_dataarray(self.plotfunc, "x", "y") for ax in g.axes.flat: assert ax.has_data() @@ -1475,7 +1554,7 @@ def test_facetgrid_cbar_kwargs(self): ) # catch contour case - if hasattr(g, "cbar"): + if g.cbar is not None: assert get_colorbar_label(g.cbar) == "test_label" def test_facetgrid_no_cbar_ax(self): @@ -1599,7 +1678,7 @@ def test_cmap_and_color_both(self): self.plotmethod(colors="k", cmap="RdBu") def list_of_colors_in_cmap_raises_error(self): - with raises_regex(ValueError, "list of colors"): + with pytest.raises(ValueError, match=r"list of colors"): self.plotmethod(cmap=["k", "b"]) @pytest.mark.slow @@ -1648,6 +1727,52 @@ def test_dont_infer_interval_breaks_for_cartopy(self): assert artist.get_array().size <= self.darray.size +class TestPcolormeshLogscale(PlotTestCase): + """ + Test pcolormesh axes when x and y are in logscale + """ + + plotfunc = staticmethod(xplt.pcolormesh) + + @pytest.fixture(autouse=True) + def setUp(self): + self.boundaries = (-1, 9, -4, 3) + shape = (8, 11) + x = np.logspace(self.boundaries[0], self.boundaries[1], shape[1]) + y = np.logspace(self.boundaries[2], self.boundaries[3], shape[0]) + da = DataArray( + easy_array(shape, start=-1), + dims=["y", "x"], + coords={"y": y, "x": x}, + name="testvar", + ) + self.darray = da + + def test_interval_breaks_logspace(self): + """ + Check if the outer vertices of the pcolormesh are the expected values + + Checks bugfix for #5333 + """ + artist = self.darray.plot.pcolormesh(xscale="log", yscale="log") + + # Grab the coordinates of the vertices of the Patches + x_vertices = [p.vertices[:, 0] for p in artist.properties()["paths"]] + y_vertices = [p.vertices[:, 1] for p in artist.properties()["paths"]] + + # Get the maximum and minimum values for each set of vertices + xmin, xmax = np.min(x_vertices), np.max(x_vertices) + ymin, ymax = np.min(y_vertices), np.max(y_vertices) + + # Check if they are equal to 10 to the power of the outer value of its + # corresponding axis plus or minus the interval in the logspace + log_interval = 0.5 + np.testing.assert_allclose(xmin, 10 ** (self.boundaries[0] - log_interval)) + np.testing.assert_allclose(xmax, 10 ** (self.boundaries[1] + log_interval)) + np.testing.assert_allclose(ymin, 10 ** (self.boundaries[2] - log_interval)) + np.testing.assert_allclose(ymax, 10 ** (self.boundaries[3] + log_interval)) + + @pytest.mark.slow class TestImshow(Common2dMixin, PlotTestCase): @@ -1671,7 +1796,7 @@ def test_default_aspect_is_auto(self): @pytest.mark.slow def test_cannot_change_mpl_aspect(self): - with raises_regex(ValueError, "not available in xarray"): + with pytest.raises(ValueError, match=r"not available in xarray"): self.darray.plot.imshow(aspect="equal") # with numbers we fall back to fig control @@ -1691,7 +1816,7 @@ def test_seaborn_palette_needs_levels(self): self.plotmethod(cmap="husl") def test_2d_coord_names(self): - with raises_regex(ValueError, "requires 1D coordinates"): + with pytest.raises(ValueError, match=r"requires 1D coordinates"): self.plotmethod(x="x2d", y="y2d") def test_plot_rgb_image(self): @@ -1785,6 +1910,95 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 +class TestSurface(Common2dMixin, PlotTestCase): + + plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} + + def test_primitive_artist_returned(self): + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + + @pytest.mark.slow + def test_2d_coord_names(self): + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() + + def test_xyincrease_false_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_xyincrease_true_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_can_pass_in_axis(self): + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + + def test_default_cmap(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_diverging_color_limits(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_colorbar_kwargs(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_cmap_and_color_both(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_seaborn_palette_as_cmap(self): + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() + + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self): + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + # Infering labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + @requires_matplotlib_3_3_0 + def test_viridis_cmap(self): + return super().test_viridis_cmap() + + @requires_matplotlib_3_3_0 + def test_can_change_default_cmap(self): + return super().test_can_change_default_cmap() + + @requires_matplotlib_3_3_0 + def test_colorbar_default_label(self): + return super().test_colorbar_default_label() + + @requires_matplotlib_3_3_0 + def test_facetgrid_map_only_appends_mappables(self): + return super().test_facetgrid_map_only_appends_mappables() + + class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): @@ -1852,7 +2066,7 @@ def test_empty_cell(self): @pytest.mark.slow def test_norow_nocol_error(self): - with raises_regex(ValueError, r"[Rr]ow"): + with pytest.raises(ValueError, match=r"[Rr]ow"): xplt.FacetGrid(self.darray) @pytest.mark.slow @@ -1873,7 +2087,7 @@ def test_float_index(self): @pytest.mark.slow def test_nonunique_index_error(self): self.darray.coords["z"] = [0.1, 0.2, 0.2] - with raises_regex(ValueError, r"[Uu]nique"): + with pytest.raises(ValueError, match=r"[Uu]nique"): xplt.FacetGrid(self.darray, col="z") @pytest.mark.slow @@ -1940,10 +2154,10 @@ def test_figure_size(self): g = xplt.FacetGrid(self.darray, col="z", figsize=(9, 4)) assert_array_equal(g.fig.get_size_inches(), (9, 4)) - with raises_regex(ValueError, "cannot provide both"): + with pytest.raises(ValueError, match=r"cannot provide both"): g = xplt.plot(self.darray, row=2, col="z", figsize=(6, 4), size=6) - with raises_regex(ValueError, "Can't use"): + with pytest.raises(ValueError, match=r"Can't use"): g = xplt.plot(self.darray, row=2, col="z", ax=plt.gca(), size=6) @pytest.mark.slow @@ -2152,6 +2366,121 @@ def test_wrong_num_of_dimensions(self): self.darray.plot.line(row="row", hue="hue") +@requires_matplotlib +class TestDatasetQuiverPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 4, 4), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 4, 4]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_quiver(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.quiver.Quiver) + with pytest.raises(ValueError, match=r"specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u") + + with pytest.raises(ValueError, match=r"hue_style"): + self.ds.isel(row=0, col=0).plot.quiver( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.quiver( + x="x", y="y", u="u", v="v", row="row", col="col", scale=1, hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.quiver.Quiver) + assert "uunits" in fg.quiverkey.text.get_text() + + with figure_context(): + fg = self.ds.plot.quiver( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + scale=1, + hue="mag", + add_guide=False, + ) + assert fg.quiverkey is None + with pytest.raises(ValueError, match=r"Please provide scale"): + self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col") + + +@requires_matplotlib +class TestDatasetStreamplotPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [ + DataArray( + np.random.randn(3, 3, 2, 2), + dims=["x", "y", "row", "col"], + coords=[range(k) for k in [3, 3, 2, 2]], + ) + for _ in [1, 2] + ] + ds = Dataset({"u": das[0], "v": das[1]}) + ds.x.attrs["units"] = "xunits" + ds.y.attrs["units"] = "yunits" + ds.col.attrs["units"] = "colunits" + ds.row.attrs["units"] = "rowunits" + ds.u.attrs["units"] = "uunits" + ds.v.attrs["units"] = "vunits" + ds["mag"] = np.hypot(ds.u, ds.v) + self.ds = ds + + def test_streamline(self): + with figure_context(): + hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v") + assert isinstance(hdl, mpl.collections.LineCollection) + with pytest.raises(ValueError, match=r"specify x, y, u, v"): + self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u") + + with pytest.raises(ValueError, match=r"hue_style"): + self.ds.isel(row=0, col=0).plot.streamplot( + x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" + ) + + def test_facetgrid(self): + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", y="y", u="u", v="v", row="row", col="col", hue="mag" + ) + for handle in fg._mappables: + assert isinstance(handle, mpl.collections.LineCollection) + + with figure_context(): + fg = self.ds.plot.streamplot( + x="x", + y="y", + u="u", + v="v", + row="row", + col="col", + hue="mag", + add_guide=False, + ) + + @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) @@ -2194,7 +2523,13 @@ def test_accessor(self): def test_add_guide(self, add_guide, hue_style, legend, colorbar): meta_data = _infer_meta_data( - self.ds, x="A", y="B", hue="hue", hue_style=hue_style, add_guide=add_guide + self.ds, + x="A", + y="B", + hue="hue", + hue_style=hue_style, + add_guide=add_guide, + funcname="scatter", ) assert meta_data["add_legend"] is legend assert meta_data["add_colorbar"] is colorbar @@ -2273,6 +2608,9 @@ def test_facetgrid_hue_style(self): def test_scatter(self, x, y, hue, markersize): self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) + with pytest.raises(ValueError, match=r"u, v"): + self.ds.plot.scatter(x, y, u="col", v="row") + def test_non_numeric_legend(self): ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] @@ -2372,74 +2710,99 @@ def setUp(self): self.darray = darray def test_ncaxis_notinstalled_line_plot(self): - with raises_regex(ImportError, "optional `nc-time-axis`"): + with pytest.raises(ImportError, match=r"optional `nc-time-axis`"): self.darray.plot.line() -test_da_list = [ - DataArray(easy_array((10,))), - DataArray(easy_array((10, 3))), - DataArray(easy_array((10, 3, 2))), -] - - @requires_matplotlib class TestAxesKwargs: - @pytest.mark.parametrize("da", test_da_list) + @pytest.fixture(params=[1, 2, 3]) + def data_array(self, request): + """ + Return a simple DataArray + """ + dims = request.param + if dims == 1: + return DataArray(easy_array((10,))) + if dims == 2: + return DataArray(easy_array((10, 3))) + if dims == 3: + return DataArray(easy_array((10, 3, 2))) + + @pytest.fixture(params=[1, 2]) + def data_array_logspaced(self, request): + """ + Return a simple DataArray with logspaced coordinates + """ + dims = request.param + if dims == 1: + return DataArray( + np.arange(7), dims=("x",), coords={"x": np.logspace(-3, 3, 7)} + ) + if dims == 2: + return DataArray( + np.arange(16).reshape(4, 4), + dims=("y", "x"), + coords={"x": np.logspace(-1, 2, 4), "y": np.logspace(-5, -1, 4)}, + ) + @pytest.mark.parametrize("xincrease", [True, False]) - def test_xincrease_kwarg(self, da, xincrease): + def test_xincrease_kwarg(self, data_array, xincrease): with figure_context(): - da.plot(xincrease=xincrease) + data_array.plot(xincrease=xincrease) assert plt.gca().xaxis_inverted() == (not xincrease) - @pytest.mark.parametrize("da", test_da_list) @pytest.mark.parametrize("yincrease", [True, False]) - def test_yincrease_kwarg(self, da, yincrease): + def test_yincrease_kwarg(self, data_array, yincrease): with figure_context(): - da.plot(yincrease=yincrease) + data_array.plot(yincrease=yincrease) assert plt.gca().yaxis_inverted() == (not yincrease) - @pytest.mark.parametrize("da", test_da_list) - @pytest.mark.parametrize("xscale", ["linear", "log", "logit", "symlog"]) - def test_xscale_kwarg(self, da, xscale): + @pytest.mark.parametrize("xscale", ["linear", "logit", "symlog"]) + def test_xscale_kwarg(self, data_array, xscale): with figure_context(): - da.plot(xscale=xscale) + data_array.plot(xscale=xscale) assert plt.gca().get_xscale() == xscale - @pytest.mark.parametrize( - "da", [DataArray(easy_array((10,))), DataArray(easy_array((10, 3)))] - ) - @pytest.mark.parametrize("yscale", ["linear", "log", "logit", "symlog"]) - def test_yscale_kwarg(self, da, yscale): + @pytest.mark.parametrize("yscale", ["linear", "logit", "symlog"]) + def test_yscale_kwarg(self, data_array, yscale): + with figure_context(): + data_array.plot(yscale=yscale) + assert plt.gca().get_yscale() == yscale + + def test_xscale_log_kwarg(self, data_array_logspaced): + xscale = "log" with figure_context(): - da.plot(yscale=yscale) + data_array_logspaced.plot(xscale=xscale) + assert plt.gca().get_xscale() == xscale + + def test_yscale_log_kwarg(self, data_array_logspaced): + yscale = "log" + with figure_context(): + data_array_logspaced.plot(yscale=yscale) assert plt.gca().get_yscale() == yscale - @pytest.mark.parametrize("da", test_da_list) - def test_xlim_kwarg(self, da): + def test_xlim_kwarg(self, data_array): with figure_context(): expected = (0.0, 1000.0) - da.plot(xlim=[0, 1000]) + data_array.plot(xlim=[0, 1000]) assert plt.gca().get_xlim() == expected - @pytest.mark.parametrize("da", test_da_list) - def test_ylim_kwarg(self, da): + def test_ylim_kwarg(self, data_array): with figure_context(): - da.plot(ylim=[0, 1000]) + data_array.plot(ylim=[0, 1000]) expected = (0.0, 1000.0) assert plt.gca().get_ylim() == expected - @pytest.mark.parametrize("da", test_da_list) - def test_xticks_kwarg(self, da): + def test_xticks_kwarg(self, data_array): with figure_context(): - da.plot(xticks=np.arange(5)) + data_array.plot(xticks=np.arange(5)) expected = np.arange(5).tolist() assert_array_equal(plt.gca().get_xticks(), expected) - @pytest.mark.parametrize("da", test_da_list) - def test_yticks_kwarg(self, da): + def test_yticks_kwarg(self, data_array): with figure_context(): - da.plot(yticks=np.arange(5)) + data_array.plot(yticks=np.arange(5)) expected = np.arange(5) assert_array_equal(plt.gca().get_yticks(), expected) @@ -2517,10 +2880,76 @@ def test_get_axis(): @requires_cartopy def test_get_axis_cartopy(): - kwargs = {"projection": ctpy.crs.PlateCarree()} + kwargs = {"projection": cartopy.crs.PlateCarree()} with figure_context(): ax = get_axis(**kwargs) - assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot) + assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + + +@requires_matplotlib +def test_maybe_gca(): + + with figure_context(): + ax = _maybe_gca(aspect=1) + + assert isinstance(ax, mpl.axes.Axes) + assert ax.get_aspect() == 1 + + with figure_context(): + + # create figure without axes + plt.figure() + ax = _maybe_gca(aspect=1) + + assert isinstance(ax, mpl.axes.Axes) + assert ax.get_aspect() == 1 + + with figure_context(): + existing_axes = plt.axes() + ax = _maybe_gca(aspect=1) + + # re-uses the existing axes + assert existing_axes == ax + # kwargs are ignored when reusing axes + assert ax.get_aspect() == "auto" + + +@requires_matplotlib +@pytest.mark.parametrize( + "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", + [ + ("A", "B", None, None, None, None, None, None, None), + ("B", "A", None, "w", None, None, None, True, None), + ("A", "B", None, "y", "x", None, None, True, True), + ("A", "B", "z", None, None, None, None, None, None), + ("B", "A", "z", "w", None, None, None, True, None), + ("A", "B", "z", "y", "x", None, None, True, True), + ("A", "B", "z", "y", "x", "w", None, True, True), + ], +) +def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar): + """Test datarray scatter. Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset() + + extra_coords = [v for v in [x, hue, markersize] if v is not None] + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + coords.update({v: ds[v] for v in extra_coords}) + + darray = xr.DataArray(ds[y], coords=coords) + + with figure_context(): + darray.plot._scatter( + x=x, + z=z, + hue=hue, + markersize=markersize, + add_legend=add_legend, + add_colorbar=add_colorbar, + ) @requires_matplotlib @@ -2555,7 +2984,7 @@ def test_cftime_grouping(self): ) # TODO: can't plot single vector with plot2d when axis is CFTime - with raises_regex(TypeError, "unsupported operand"): + with pytest.raises(TypeError, match="unsupported operand"): with figure_context(): ds.variable.sel(lat=0).isel(time=slice(-1)).groupby("time.day").plot( col="day", x="lon", sharey=True @@ -2577,5 +3006,7 @@ def test_stacked_groupby_2d_plot(self): self.ds.variable.groupby(id2).plot(col="id", col_wrap=2) def test_groupby_plot_errors(self): - with raises_regex(ValueError, "Expected one of 'row' or 'col' to be 'id'"): + with pytest.raises( + ValueError, match="Expected one of 'row' or 'col' to be 'id'" + ): self.ds.variable.groupby(self.ds.id).plot.line() diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 64a1c563dba..b7a5f9405d1 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -45,6 +45,17 @@ def test_remove_duplicates(dummy_duplicated_entrypoints): assert len(entrypoints) == 2 +def test_broken_plugin(): + broken_backend = pkg_resources.EntryPoint.parse( + "broken_backend = xarray.tests.test_plugins:backend_1" + ) + with pytest.warns(RuntimeWarning) as record: + _ = plugins.build_engines([broken_backend]) + assert len(record) == 1 + message = str(record[0].message) + assert "Engine 'broken_backend'" in message + + def test_remove_duplicates_warnings(dummy_duplicated_entrypoints): with pytest.warns(RuntimeWarning) as record: @@ -58,13 +69,13 @@ def test_remove_duplicates_warnings(dummy_duplicated_entrypoints): @mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None)) -def test_create_engines_dict(): +def test_backends_dict_from_pkg(): specs = [ "engine1 = xarray.tests.test_plugins:backend_1", "engine2 = xarray.tests.test_plugins:backend_2", ] entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs] - engines = plugins.create_engines_dict(entrypoints) + engines = plugins.backends_dict_from_pkg(entrypoints) assert len(engines) == 2 assert engines.keys() == set(("engine1", "engine2")) @@ -111,8 +122,62 @@ def test_build_engines(): "cfgrib = xarray.tests.test_plugins:backend_1" ) backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1) assert backend_entrypoints["cfgrib"].open_dataset_parameters == ( "filename_or_obj", "decoder", ) + + +@mock.patch( + "pkg_resources.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) +def test_build_engines_sorted(): + dummy_pkg_entrypoints = [ + pkg_resources.EntryPoint.parse( + "dummy2 = xarray.tests.test_plugins:backend_1", + ), + pkg_resources.EntryPoint.parse( + "dummy1 = xarray.tests.test_plugins:backend_1", + ), + ] + backend_entrypoints = plugins.build_engines(dummy_pkg_entrypoints) + backend_entrypoints = list(backend_entrypoints) + + indices = [] + for be in plugins.STANDARD_BACKENDS_ORDER: + try: + index = backend_entrypoints.index(be) + backend_entrypoints.pop(index) + indices.append(index) + except ValueError: + pass + + assert set(indices) < {0, -1} + assert list(backend_entrypoints) == sorted(backend_entrypoints) + + +@mock.patch( + "xarray.backends.plugins.list_engines", + mock.MagicMock(return_value={"dummy": DummyBackendEntrypointArgs()}), +) +def test_no_matching_engine_found(): + with pytest.raises(ValueError, match=r"did not find a match in any"): + plugins.guess_engine("not-valid") + + with pytest.raises(ValueError, match=r"found the following matches with the input"): + plugins.guess_engine("foo.nc") + + +@mock.patch( + "xarray.backends.plugins.list_engines", + mock.MagicMock(return_value={}), +) +def test_engines_not_installed(): + with pytest.raises(ValueError, match=r"xarray is unable to open"): + plugins.guess_engine("not-valid") + + with pytest.raises(ValueError, match=r"found the following matches with the input"): + plugins.guess_engine("foo.nc") diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index 49b6a58694e..7401b15da42 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -8,19 +8,14 @@ import xarray as xr import xarray.ufuncs as xu from xarray import DataArray, Variable -from xarray.core.npcompat import IS_NEP18_ACTIVE from xarray.core.pycompat import sparse_array_type from . import assert_equal, assert_identical, requires_dask +filterwarnings = pytest.mark.filterwarnings param = pytest.param xfail = pytest.mark.xfail -if not IS_NEP18_ACTIVE: - pytest.skip( - "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", allow_module_level=True - ) - sparse = pytest.importorskip("sparse") @@ -124,12 +119,18 @@ def test_variable_property(prop): param( do("argmax"), True, - marks=xfail(reason="Missing implementation for np.argmin"), + marks=[ + xfail(reason="Missing implementation for np.argmin"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], ), param( do("argmin"), True, - marks=xfail(reason="Missing implementation for np.argmax"), + marks=[ + xfail(reason="Missing implementation for np.argmax"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], ), param( do("argsort"), @@ -227,6 +228,10 @@ def test_variable_method(func, sparse_output): ret_s = func(var_s) ret_d = func(var_d) + # TODO: figure out how to verify the results of each method + if isinstance(ret_d, xr.Variable) and isinstance(ret_d.data, sparse.SparseArray): + ret_d = ret_d.copy(data=ret_d.data.todense()) + if sparse_output: assert isinstance(ret_s.data, sparse.SparseArray) assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) @@ -375,12 +380,18 @@ def test_dataarray_property(prop): param( do("argmax"), True, - marks=xfail(reason="Missing implementation for np.argmax"), + marks=[ + xfail(reason="Missing implementation for np.argmax"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], ), param( do("argmin"), True, - marks=xfail(reason="Missing implementation for np.argmin"), + marks=[ + xfail(reason="Missing implementation for np.argmin"), + filterwarnings("ignore:Behaviour of argmin/argmax"), + ], ), param( do("argsort"), diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 30ea6aaaee9..dc1db4dc8d7 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -1,8 +1,9 @@ +import warnings + import numpy as np import pytest import xarray as xr -from xarray.core.npcompat import IS_NEP18_ACTIVE from . import has_dask @@ -97,10 +98,6 @@ def test_assert_duckarray_equal_failing(duckarray, obj1, obj2): pytest.param( np.array, id="numpy", - marks=pytest.mark.skipif( - not IS_NEP18_ACTIVE, - reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", - ), ), pytest.param( dask_from_array, @@ -127,3 +124,41 @@ def test_assert_duckarray_equal(duckarray, obj1, obj2): b = duckarray(obj2) xr.testing.assert_duckarray_equal(a, b) + + +@pytest.mark.parametrize( + "func", + [ + "assert_equal", + "assert_identical", + "assert_allclose", + "assert_duckarray_equal", + "assert_duckarray_allclose", + ], +) +def test_ensure_warnings_not_elevated(func): + # make sure warnings are not elevated to errors in the assertion functions + # e.g. by @pytest.mark.filterwarnings("error") + # see https://github.com/pydata/xarray/pull/4760#issuecomment-774101639 + + # define a custom Variable class that raises a warning in assert_* + class WarningVariable(xr.Variable): + @property # type: ignore[misc] + def dims(self): + warnings.warn("warning in test") + return super().dims + + def __array__(self): + warnings.warn("warning in test") + return super().__array__() + + a = WarningVariable("x", [1]) + b = WarningVariable("x", [2]) + + with warnings.catch_warnings(record=True) as w: + # elevate warnings to errors + warnings.filterwarnings("error") + with pytest.raises(AssertionError): + getattr(xr.testing, func)(a, b) + + assert len(w) > 0 diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index a2eb159f624..225fda08f68 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -1,6 +1,3 @@ -import os -from contextlib import suppress - import pytest from xarray import DataArray, tutorial @@ -13,20 +10,31 @@ class TestLoadDataset: @pytest.fixture(autouse=True) def setUp(self): self.testfile = "tiny" - self.testfilepath = os.path.expanduser( - os.sep.join(("~", ".xarray_tutorial_data", self.testfile)) - ) - with suppress(OSError): - os.remove(f"{self.testfilepath}.nc") - with suppress(OSError): - os.remove(f"{self.testfilepath}.md5") - - def test_download_from_github(self): - ds = tutorial.open_dataset(self.testfile).load() + + def test_download_from_github(self, tmp_path): + cache_dir = tmp_path / tutorial._default_cache_dir_name + ds = tutorial.open_dataset(self.testfile, cache_dir=cache_dir).load() tiny = DataArray(range(5), name="tiny").to_dataset() assert_identical(ds, tiny) - def test_download_from_github_load_without_cache(self): - ds_nocache = tutorial.open_dataset(self.testfile, cache=False).load() - ds_cache = tutorial.open_dataset(self.testfile).load() + def test_download_from_github_load_without_cache(self, tmp_path, monkeypatch): + cache_dir = tmp_path / tutorial._default_cache_dir_name + + ds_nocache = tutorial.open_dataset( + self.testfile, cache=False, cache_dir=cache_dir + ).load() + ds_cache = tutorial.open_dataset(self.testfile, cache_dir=cache_dir).load() assert_identical(ds_cache, ds_nocache) + + def test_download_rasterio_from_github_load_without_cache( + self, tmp_path, monkeypatch + ): + cache_dir = tmp_path / tutorial._default_cache_dir_name + + arr_nocache = tutorial.open_rasterio( + "RGB.byte", cache=False, cache_dir=cache_dir + ).load() + arr_cache = tutorial.open_rasterio( + "RGB.byte", cache=True, cache_dir=cache_dir + ).load() + assert_identical(arr_cache, arr_nocache) diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 26241152dfa..e8c3af4518f 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -8,7 +8,7 @@ from . import assert_array_equal from . import assert_identical as assert_identical_ -from . import mock, raises_regex +from . import mock def assert_identical(a, b): @@ -79,7 +79,7 @@ def test_groupby(): assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) assert_identical(ds.a, np.maximum(group_mean.a, arr_grouped)) - with raises_regex(ValueError, "mismatched lengths for dimension"): + with pytest.raises(ValueError, match=r"mismatched lengths for dimension"): np.maximum(ds.a.variable, ds_grouped) @@ -136,7 +136,7 @@ def test_dask_defers_to_xarray(): def test_gufunc_methods(): xarray_obj = xr.DataArray([1, 2, 3]) - with raises_regex(NotImplementedError, "reduce method"): + with pytest.raises(NotImplementedError, match=r"reduce method"): np.add.reduce(xarray_obj, 1) @@ -144,7 +144,7 @@ def test_out(): xarray_obj = xr.DataArray([1, 2, 3]) # xarray out arguments should raise - with raises_regex(NotImplementedError, "`out` argument"): + with pytest.raises(NotImplementedError, match=r"`out` argument"): np.add(xarray_obj, 1, out=xarray_obj) # but non-xarray should be OK @@ -156,7 +156,7 @@ def test_out(): def test_gufuncs(): xarray_obj = xr.DataArray([1, 2, 3]) fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) - with raises_regex(NotImplementedError, "generalized ufuncs"): + with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 76dd830de23..2140047f38e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5,11 +5,22 @@ import pandas as pd import pytest -import xarray as xr -from xarray.core import dtypes -from xarray.core.npcompat import IS_NEP18_ACTIVE +try: + import matplotlib.pyplot as plt +except ImportError: + pass -from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical +import xarray as xr +from xarray.core import dtypes, duck_array_ops + +from . import ( + assert_allclose, + assert_duckarray_allclose, + assert_equal, + assert_identical, + requires_matplotlib, +) +from .test_plot import PlotTestCase from .test_variable import _PAD_XR_NP_ARGS pint = pytest.importorskip("pint") @@ -23,9 +34,6 @@ pytestmark = [ - pytest.mark.skipif( - not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled" - ), pytest.mark.filterwarnings("error::pint.UnitStrippedWarning"), ] @@ -280,13 +288,13 @@ class method: This is works a bit similar to using `partial(Class.method, arg, kwarg)` """ - def __init__(self, name, *args, **kwargs): + def __init__(self, name, *args, fallback_func=None, **kwargs): self.name = name + self.fallback = fallback_func self.args = args self.kwargs = kwargs def __call__(self, obj, *args, **kwargs): - from collections.abc import Callable from functools import partial all_args = merge_args(self.args, args) @@ -302,21 +310,23 @@ def __call__(self, obj, *args, **kwargs): if not isinstance(obj, xarray_classes): # remove typical xarray args like "dim" exclude_kwargs = ("dim", "dims") + # TODO: figure out a way to replace dim / dims with axis all_kwargs = { key: value for key, value in all_kwargs.items() if key not in exclude_kwargs } - - func = getattr(obj, self.name, None) - - if func is None or not isinstance(func, Callable): - # fall back to module level numpy functions if not a xarray object - if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): - numpy_func = getattr(np, self.name) - func = partial(numpy_func, obj) + if self.fallback is not None: + func = partial(self.fallback, obj) else: - raise AttributeError(f"{obj} has no method named '{self.name}'") + func = getattr(obj, self.name, None) + + if func is None or not callable(func): + # fall back to module level numpy functions + numpy_func = getattr(np, self.name) + func = partial(numpy_func, obj) + else: + func = getattr(obj, self.name) return func(*all_args, **all_kwargs) @@ -3666,6 +3676,65 @@ def test_stacking_reordering(self, func, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) + @pytest.mark.parametrize( + "variant", + ( + pytest.param( + "dims", marks=pytest.mark.skip(reason="indexes don't support units") + ), + "coords", + ), + ) + @pytest.mark.parametrize( + "func", + ( + method("differentiate", fallback_func=np.gradient), + method("integrate", fallback_func=duck_array_ops.cumulative_trapezoid), + method("cumulative_integrate", fallback_func=duck_array_ops.trapz), + ), + ids=repr, + ) + def test_differentiate_integrate(self, func, variant, dtype): + data_unit = unit_registry.m + unit = unit_registry.s + + variants = { + "dims": ("x", unit, 1), + "coords": ("u", 1, unit), + } + coord, dim_unit, coord_unit = variants.get(variant) + + array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit + + x = np.arange(array.shape[0]) * dim_unit + y = np.arange(array.shape[1]) * dim_unit + + u = np.linspace(0, 1, array.shape[0]) * coord_unit + + data_array = xr.DataArray( + data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y") + ) + # we want to make sure the output unit is correct + units = extract_units(data_array) + units.update( + extract_units( + func( + data_array.data, + getattr(data_array, coord).data, + axis=0, + ) + ) + ) + + expected = attach_units( + func(strip_units(data_array), coord=strip_units(coord)), + units, + ) + actual = func(data_array, coord=coord) + + assert_units_equal(expected, actual) + assert_identical(expected, actual) + @pytest.mark.parametrize( "variant", ( @@ -3680,8 +3749,6 @@ def test_stacking_reordering(self, func, dtype): "func", ( method("diff", dim="x"), - method("differentiate", coord="x"), - method("integrate", coord="x"), method("quantile", q=[0.25, 0.75]), method("reduce", func=np.sum, dim="x"), pytest.param(lambda x: x.dot(x), id="method_dot"), @@ -3972,35 +4039,6 @@ def test_repr(self, func, variant, dtype): @pytest.mark.parametrize( "func", ( - function("all"), - function("any"), - pytest.param( - function("argmax"), - marks=pytest.mark.skip( - reason="calling np.argmax as a function on xarray objects is not " - "supported" - ), - ), - pytest.param( - function("argmin"), - marks=pytest.mark.skip( - reason="calling np.argmin as a function on xarray objects is not " - "supported" - ), - ), - function("max"), - function("min"), - function("mean"), - pytest.param( - function("median"), - marks=pytest.mark.xfail(reason="median does not work with dataset yet"), - ), - function("sum"), - function("prod"), - function("std"), - function("var"), - function("cumsum"), - function("cumprod"), method("all"), method("any"), method("argmax", dim="x"), @@ -5538,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype): assert_units_equal(expected, actual) assert_equal(expected, actual) + + +@requires_matplotlib +class TestPlots(PlotTestCase): + def test_units_in_line_plot_labels(self): + arr = np.linspace(1, 10, 3) * unit_registry.Pa + # TODO make coord a Quantity once unit-aware indexes supported + x_coord = xr.DataArray( + np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"} + ) + da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure") + + da.plot.line() + + ax = plt.gca() + assert ax.get_ylabel() == "pressure [pascal]" + assert ax.get_xlabel() == "x [meters]" + + def test_units_in_2d_plot_labels(self): + arr = np.ones((2, 3)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") + + fig, (ax, cax) = plt.subplots(1, 2) + ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True) + + assert cax.get_ylabel() == "pressure [pascal]" diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 193c45f01cd..ce796e9de49 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -7,9 +7,9 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils -from xarray.core.utils import either_dict_or_kwargs +from xarray.core.utils import either_dict_or_kwargs, iterate_nested -from . import assert_array_equal, raises_regex, requires_cftime, requires_dask +from . import assert_array_equal, requires_cftime, requires_dask from .test_coding_times import _all_cftime_date_types @@ -153,9 +153,9 @@ def test_compat_dict_intersection(self): def test_compat_dict_union(self): assert {"a": "A", "b": "B", "c": "C"} == utils.compat_dict_union(self.x, self.y) - with raises_regex( + with pytest.raises( ValueError, - "unsafe to merge dictionaries without " + match=r"unsafe to merge dictionaries without " "overriding values; conflicting key", ): utils.compat_dict_union(self.x, self.z) @@ -199,12 +199,6 @@ def test_frozen(self): "Frozen({'b': 'B', 'a': 'A'})", ) - def test_sorted_keys_dict(self): - x = {"a": 1, "b": 2, "c": 3} - y = utils.SortedKeysDict(x) - assert list(y) == ["a", "b", "c"] - assert repr(utils.SortedKeysDict()) == "SortedKeysDict({})" - def test_repr_object(): obj = utils.ReprObject("foo") @@ -233,15 +227,6 @@ def test_is_remote_uri(): assert not utils.is_remote_uri("example.nc") -def test_is_grib_path(): - assert not utils.is_grib_path("example.nc") - assert not utils.is_grib_path("example.grib ") - assert utils.is_grib_path("example.grib") - assert utils.is_grib_path("example.grib2") - assert utils.is_grib_path("example.grb") - assert utils.is_grib_path("example.grb2") - - class Test_is_uniform_and_sorted: def test_sorted_uniform(self): assert utils.is_uniform_spaced(np.arange(5)) @@ -330,3 +315,18 @@ def test_infix_dims(supplied, all_, expected): def test_infix_dims_errors(supplied, all_): with pytest.raises(ValueError): list(utils.infix_dims(supplied, all_)) + + +@pytest.mark.parametrize( + "nested_list, expected", + [ + ([], []), + ([1], [1]), + ([1, 2, 3], [1, 2, 3]), + ([[1]], [1]), + ([[1, 2], [3, 4]], [1, 2, 3, 4]), + ([[[1, 2, 3], [4]], [5, 6]], [1, 2, 3, 4, 5, 6]), + ], +) +def test_iterate_nested(nested_list, expected): + assert list(iterate_nested(nested_list)) == expected diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index e1ae3e1f258..7f3ba9123d9 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -8,18 +8,18 @@ import pytest import pytz -from xarray import Coordinate, Dataset, IndexVariable, Variable, set_options +from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.common import full_like, ones_like, zeros_like from xarray.core.indexing import ( BasicIndexer, CopyOnWriteArray, DaskIndexingAdapter, - LazilyOuterIndexedArray, + LazilyIndexedArray, MemoryCachedArray, NumpyIndexingAdapter, OuterIndexer, - PandasIndexAdapter, + PandasIndexingAdapter, VectorizedIndexer, ) from xarray.core.pycompat import dask_array_type @@ -32,8 +32,10 @@ assert_array_equal, assert_equal, assert_identical, - raises_regex, + raise_if_dask_computes, + requires_cupy, requires_dask, + requires_pint_0_15, requires_sparse, source_ndarray, ) @@ -47,6 +49,11 @@ ] +@pytest.fixture +def var(): + return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) + + class VariableSubclassobjects: def test_properties(self): data = 0.5 * np.arange(10) @@ -119,7 +126,7 @@ def test_getitem_1d_fancy(self): v_new = v[[True, False, True]] assert_identical(v[[0, 2]], v_new) - with raises_regex(IndexError, "Boolean indexer should"): + with pytest.raises(IndexError, match=r"Boolean indexer should"): ind = Variable(("a",), [True, False, True]) v[ind] @@ -297,14 +304,14 @@ def test_object_conversion(self): def test_datetime64_valid_range(self): data = np.datetime64("1250-01-01", "us") pderror = pd.errors.OutOfBoundsDatetime - with raises_regex(pderror, "Out of bounds nanosecond"): + with pytest.raises(pderror, match=r"Out of bounds nanosecond"): self.cls(["t"], [data]) @pytest.mark.xfail(reason="pandas issue 36615") def test_timedelta64_valid_range(self): data = np.timedelta64("200000", "D") pderror = pd.errors.OutOfBoundsTimedelta - with raises_regex(pderror, "Out of bounds nanosecond"): + with pytest.raises(pderror, match=r"Out of bounds nanosecond"): self.cls(["t"], [data]) def test_pandas_data(self): @@ -454,7 +461,7 @@ def test_concat(self): assert_identical( Variable(["b", "a"], np.array([x, y])), Variable.concat((v, w), "b") ) - with raises_regex(ValueError, "Variable has dimensions"): + with pytest.raises(ValueError, match=r"Variable has dimensions"): Variable.concat([v, Variable(["c"], y)], "b") # test indexers actual = Variable.concat( @@ -469,7 +476,7 @@ def test_concat(self): assert_identical(v, Variable.concat([v[:1], v[1:]], "time")) # test dimension order assert_identical(v, Variable.concat([v[:, :5], v[:, 5:]], "x")) - with raises_regex(ValueError, "all input arrays must have"): + with pytest.raises(ValueError, match=r"all input arrays must have"): Variable.concat([v[:, 0], v[:, 1:]], "x") def test_concat_attrs(self): @@ -530,7 +537,7 @@ def test_copy_index(self): v = self.cls("x", midx) for deep in [True, False]: w = v.copy(deep=deep) - assert isinstance(w._data, PandasIndexAdapter) + assert isinstance(w._data, PandasIndexingAdapter) assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) @@ -545,7 +552,7 @@ def test_copy_with_data(self): def test_copy_with_data_errors(self): orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = [2.5, 5.0] - with raises_regex(ValueError, "must match shape of object"): + with pytest.raises(ValueError, match=r"must match shape of object"): orig.copy(data=new_data) def test_copy_index_with_data(self): @@ -558,11 +565,11 @@ def test_copy_index_with_data(self): def test_copy_index_with_data_errors(self): orig = IndexVariable("x", np.arange(5)) new_data = np.arange(5, 20) - with raises_regex(ValueError, "must match shape of object"): + with pytest.raises(ValueError, match=r"must match shape of object"): orig.copy(data=new_data) - with raises_regex(ValueError, "Cannot assign to the .data"): + with pytest.raises(ValueError, match=r"Cannot assign to the .data"): orig.data = new_data - with raises_regex(ValueError, "Cannot assign to the .values"): + with pytest.raises(ValueError, match=r"Cannot assign to the .values"): orig.values = new_data def test_replace(self): @@ -658,12 +665,12 @@ def test_getitem_advanced(self): # with boolean variable with wrong shape ind = np.array([True, False]) - with raises_regex(IndexError, "Boolean array size 2 is "): + with pytest.raises(IndexError, match=r"Boolean array size 2 is "): v[Variable(("a", "b"), [[0, 1]]), ind] # boolean indexing with different dimension ind = Variable(["a"], [True, False, False]) - with raises_regex(IndexError, "Boolean indexer should be"): + with pytest.raises(IndexError, match=r"Boolean indexer should be"): v[dict(y=ind)] def test_getitem_uint_1d(self): @@ -793,21 +800,21 @@ def test_getitem_fancy(self): def test_getitem_error(self): v = self.cls(["x", "y"], [[0, 1, 2], [3, 4, 5]]) - with raises_regex(IndexError, "labeled multi-"): + with pytest.raises(IndexError, match=r"labeled multi-"): v[[[0, 1], [1, 2]]] ind_x = Variable(["a"], [0, 1, 1]) ind_y = Variable(["a"], [0, 1]) - with raises_regex(IndexError, "Dimensions of indexers "): + with pytest.raises(IndexError, match=r"Dimensions of indexers "): v[ind_x, ind_y] ind = Variable(["a", "b"], [[True, False], [False, True]]) - with raises_regex(IndexError, "2-dimensional boolean"): + with pytest.raises(IndexError, match=r"2-dimensional boolean"): v[dict(x=ind)] v = Variable(["x", "y", "z"], np.arange(60).reshape(3, 4, 5)) ind = Variable(["x"], [0, 1]) - with raises_regex(IndexError, "Dimensions of indexers mis"): + with pytest.raises(IndexError, match=r"Dimensions of indexers mis"): v[:, ind] @pytest.mark.parametrize( @@ -873,26 +880,100 @@ def test_pad_constant_values(self, xr_arg, np_arg): ) assert_array_equal(actual, expected) - def test_rolling_window(self): + @pytest.mark.parametrize("d, w", (("x", 3), ("y", 5))) + def test_rolling_window(self, d, w): # Just a working test. See test_nputils for the algorithm validation v = self.cls(["x", "y", "z"], np.arange(40 * 30 * 2).reshape(40, 30, 2)) - for (d, w) in [("x", 3), ("y", 5)]: - v_rolling = v.rolling_window(d, w, d + "_window") - assert v_rolling.dims == ("x", "y", "z", d + "_window") - assert v_rolling.shape == v.shape + (w,) + v_rolling = v.rolling_window(d, w, d + "_window") + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) + + v_rolling = v.rolling_window(d, w, d + "_window", center=True) + assert v_rolling.dims == ("x", "y", "z", d + "_window") + assert v_rolling.shape == v.shape + (w,) + + # dask and numpy result should be the same + v_loaded = v.load().rolling_window(d, w, d + "_window", center=True) + assert_array_equal(v_rolling, v_loaded) + + # numpy backend should not be over-written + if isinstance(v._data, np.ndarray): + with pytest.raises(ValueError): + v_loaded[0] = 1.0 + + def test_rolling_1d(self): + x = self.cls("x", np.array([1, 2, 3, 4], dtype=float)) + + kwargs = dict(dim="x", window=3, window_dim="xw") + actual = x.rolling_window(**kwargs, center=True, fill_value=np.nan) + expected = Variable( + ("x", "xw"), + np.array( + [[np.nan, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, np.nan]], dtype=float + ), + ) + assert_equal(actual, expected) - v_rolling = v.rolling_window(d, w, d + "_window", center=True) - assert v_rolling.dims == ("x", "y", "z", d + "_window") - assert v_rolling.shape == v.shape + (w,) + actual = x.rolling_window(**kwargs, center=False, fill_value=0.0) + expected = self.cls( + ("x", "xw"), + np.array([[0, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], dtype=float), + ) + assert_equal(actual, expected) - # dask and numpy result should be the same - v_loaded = v.load().rolling_window(d, w, d + "_window", center=True) - assert_array_equal(v_rolling, v_loaded) + x = self.cls(("y", "x"), np.stack([x, x * 1.1])) + actual = x.rolling_window(**kwargs, center=False, fill_value=0.0) + expected = self.cls( + ("y", "x", "xw"), np.stack([expected.data, expected.data * 1.1], axis=0) + ) + assert_equal(actual, expected) - # numpy backend should not be over-written - if isinstance(v._data, np.ndarray): - with pytest.raises(ValueError): - v_loaded[0] = 1.0 + @pytest.mark.parametrize("center", [[True, True], [False, False]]) + @pytest.mark.parametrize("dims", [("x", "y"), ("y", "z"), ("z", "x")]) + def test_nd_rolling(self, center, dims): + x = self.cls( + ("x", "y", "z"), + np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float), + ) + window = [3, 3] + actual = x.rolling_window( + dim=dims, + window=window, + window_dim=[f"{k}w" for k in dims], + center=center, + fill_value=np.nan, + ) + expected = x + for dim, win, cent in zip(dims, window, center): + expected = expected.rolling_window( + dim=dim, + window=win, + window_dim=f"{dim}w", + center=cent, + fill_value=np.nan, + ) + assert_equal(actual, expected) + + @pytest.mark.parametrize( + ("dim, window, window_dim, center"), + [ + ("x", [3, 3], "x_w", True), + ("x", 3, ("x_w", "x_w"), True), + ("x", 3, "x_w", [True, True]), + ], + ) + def test_rolling_window_errors(self, dim, window, window_dim, center): + x = self.cls( + ("x", "y", "z"), + np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float), + ) + with pytest.raises(ValueError): + x.rolling_window( + dim=dim, + window=window, + window_dim=window_dim, + center=center, + ) class TestVariable(VariableSubclassobjects): @@ -1054,11 +1135,11 @@ def test_as_variable(self): ) assert_identical(expected_extra, as_variable(xarray_tuple)) - with raises_regex(TypeError, "tuple of form"): + with pytest.raises(TypeError, match=r"tuple of form"): as_variable(tuple(data)) - with raises_regex(ValueError, "tuple of form"): # GH1016 + with pytest.raises(ValueError, match=r"tuple of form"): # GH1016 as_variable(("five", "six", "seven")) - with raises_regex(TypeError, "without an explicit list of dimensions"): + with pytest.raises(TypeError, match=r"without an explicit list of dimensions"): as_variable(data) actual = as_variable(data, name="x") @@ -1070,9 +1151,9 @@ def test_as_variable(self): data = np.arange(9).reshape((3, 3)) expected = Variable(("x", "y"), data) - with raises_regex(ValueError, "without explicit dimension names"): + with pytest.raises(ValueError, match=r"without explicit dimension names"): as_variable(data, name="x") - with raises_regex(ValueError, "has more than 1-dimension"): + with pytest.raises(ValueError, match=r"has more than 1-dimension"): as_variable(expected, name="x") # test datetime, timedelta conversion @@ -1081,6 +1162,9 @@ def test_as_variable(self): td = np.array([timedelta(days=x) for x in range(10)]) assert as_variable(td, "time").dtype.kind == "m" + with pytest.raises(TypeError): + as_variable(("x", DataArray([]))) + def test_repr(self): v = Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"}) expected = dedent( @@ -1095,12 +1179,12 @@ def test_repr(self): assert expected == repr(v) def test_repr_lazy_data(self): - v = Variable("x", LazilyOuterIndexedArray(np.arange(2e5))) + v = Variable("x", LazilyIndexedArray(np.arange(2e5))) assert "200000 values with dtype" in repr(v) - assert isinstance(v._data, LazilyOuterIndexedArray) + assert isinstance(v._data, LazilyIndexedArray) def test_detect_indexer_type(self): - """ Tests indexer type was correctly detected. """ + """Tests indexer type was correctly detected.""" data = np.random.random((10, 11)) v = Variable(["x", "y"], data) @@ -1197,7 +1281,7 @@ def test_items(self): # test iteration for n, item in enumerate(v): assert_identical(Variable(["y"], data[n]), item) - with raises_regex(TypeError, "iteration over a 0-d"): + with pytest.raises(TypeError, match=r"iteration over a 0-d"): iter(Variable([], 0)) # test setting v.values[:] = 0 @@ -1271,9 +1355,9 @@ def test_isel(self): assert_identical(v.isel(x=0), v[:, 0]) assert_identical(v.isel(x=[0, 2]), v[:, [0, 2]]) assert_identical(v.isel(time=[]), v[[]]) - with raises_regex( + with pytest.raises( ValueError, - r"Dimensions {'not_a_dim'} do not exist. Expected one or more of " + match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of " r"\('time', 'x'\)", ): v.isel(not_a_dim=0) @@ -1326,7 +1410,7 @@ def test_shift(self, fill_value): assert_identical(expected, v.shift(x=5, fill_value=fill_value)) assert_identical(expected, v.shift(x=6, fill_value=fill_value)) - with raises_regex(ValueError, "dimension"): + with pytest.raises(ValueError, match=r"dimension"): v.shift(z=0) v = Variable("x", [1, 2, 3, 4, 5], {"foo": "bar"}) @@ -1355,7 +1439,7 @@ def test_roll(self): assert_identical(expected, v.roll(x=2)) assert_identical(expected, v.roll(x=-3)) - with raises_regex(ValueError, "dimension"): + with pytest.raises(ValueError, match=r"dimension"): v.roll(z=0) def test_roll_consistency(self): @@ -1384,6 +1468,20 @@ def test_transpose(self): w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x)) assert_identical(w, w3.transpose("a", "b", "c", "d")) + # test missing dimension, raise error + with pytest.raises(ValueError): + v.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = v.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = v.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + v.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) + def test_transpose_0d(self): for value in [ 3.5, @@ -1408,7 +1506,7 @@ def test_squeeze(self): v = Variable(["x", "y"], [[1, 2]]) assert_identical(Variable(["y"], [1, 2]), v.squeeze()) assert_identical(Variable(["y"], [1, 2]), v.squeeze("x")) - with raises_regex(ValueError, "cannot select a dimension"): + with pytest.raises(ValueError, match=r"cannot select a dimension"): v.squeeze("y") def test_get_axis_num(self): @@ -1417,7 +1515,7 @@ def test_get_axis_num(self): assert v.get_axis_num(["x"]) == (0,) assert v.get_axis_num(["x", "y"]) == (0, 1) assert v.get_axis_num(["z", "y", "x"]) == (2, 1, 0) - with raises_regex(ValueError, "not found in array dim"): + with pytest.raises(ValueError, match=r"not found in array dim"): v.get_axis_num("foobar") def test_set_dims(self): @@ -1438,7 +1536,7 @@ def test_set_dims(self): expected = v assert_identical(actual, expected) - with raises_regex(ValueError, "must be a superset"): + with pytest.raises(ValueError, match=r"must be a superset"): v.set_dims(["z"]) def test_set_dims_object_dtype(self): @@ -1470,9 +1568,9 @@ def test_stack(self): def test_stack_errors(self): v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) - with raises_regex(ValueError, "invalid existing dim"): + with pytest.raises(ValueError, match=r"invalid existing dim"): v.stack(z=("x1",)) - with raises_regex(ValueError, "cannot create a new dim"): + with pytest.raises(ValueError, match=r"cannot create a new dim"): v.stack(x=("x",)) def test_unstack(self): @@ -1491,11 +1589,11 @@ def test_unstack(self): def test_unstack_errors(self): v = Variable("z", [0, 1, 2, 3]) - with raises_regex(ValueError, "invalid existing dim"): + with pytest.raises(ValueError, match=r"invalid existing dim"): v.unstack(foo={"x": 4}) - with raises_regex(ValueError, "cannot create a new dim"): + with pytest.raises(ValueError, match=r"cannot create a new dim"): v.stack(z=("z",)) - with raises_regex(ValueError, "the product of the new dim"): + with pytest.raises(ValueError, match=r"the product of the new dim"): v.unstack(z={"x": 5}) def test_unstack_2d(self): @@ -1540,9 +1638,9 @@ def test_broadcasting_failures(self): a = Variable(["x"], np.arange(10)) b = Variable(["x"], np.arange(5)) c = Variable(["x", "x"], np.arange(100).reshape(10, 10)) - with raises_regex(ValueError, "mismatched lengths"): + with pytest.raises(ValueError, match=r"mismatched lengths"): a + b - with raises_regex(ValueError, "duplicate dimensions"): + with pytest.raises(ValueError, match=r"duplicate dimensions"): a + c def test_inplace_math(self): @@ -1555,7 +1653,7 @@ def test_inplace_math(self): assert source_ndarray(v.values) is x assert_array_equal(v.values, np.arange(5) + 1) - with raises_regex(ValueError, "dimensions cannot change"): + with pytest.raises(ValueError, match=r"dimensions cannot change"): v += Variable("y", np.arange(5)) def test_reduce(self): @@ -1572,9 +1670,26 @@ def test_reduce(self): ) assert_allclose(v.mean("x"), v.reduce(np.mean, "x")) - with raises_regex(ValueError, "cannot supply both"): + with pytest.raises(ValueError, match=r"cannot supply both"): v.mean(dim="x", axis=0) + @requires_bottleneck + def test_reduce_use_bottleneck(self, monkeypatch): + def raise_if_called(*args, **kwargs): + raise RuntimeError("should not have been called") + + import bottleneck as bn + + monkeypatch.setattr(bn, "nanmin", raise_if_called) + + v = Variable("x", [0.0, np.nan, 1.0]) + with pytest.raises(RuntimeError, match="should not have been called"): + with set_options(use_bottleneck=True): + v.min() + + with set_options(use_bottleneck=False): + v.min() + @pytest.mark.parametrize("skipna", [True, False]) @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]]) @pytest.mark.parametrize( @@ -1602,7 +1717,7 @@ def test_quantile_chunked_dim_error(self): v = Variable(["x", "y"], self.d).chunk({"x": 2}) # this checks for ValueError in dask.array.apply_gufunc - with raises_regex(ValueError, "consists of multiple chunks"): + with pytest.raises(ValueError, match=r"consists of multiple chunks"): v.quantile(0.5, dim="x") @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]]) @@ -1610,16 +1725,24 @@ def test_quantile_out_of_bounds(self, q): v = Variable(["x", "y"], self.d) # escape special characters - with raises_regex(ValueError, r"Quantiles must be in the range \[0, 1\]"): + with pytest.raises( + ValueError, match=r"Quantiles must be in the range \[0, 1\]" + ): v.quantile(q, dim="x") @requires_dask @requires_bottleneck def test_rank_dask_raises(self): v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2) - with raises_regex(TypeError, "arrays stored as dask"): + with pytest.raises(TypeError, match=r"arrays stored as dask"): v.rank("x") + def test_rank_use_bottleneck(self): + v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]) + with set_options(use_bottleneck=False): + with pytest.raises(RuntimeError): + v.rank("x") + @requires_bottleneck def test_rank(self): import bottleneck as bn @@ -1643,7 +1766,7 @@ def test_rank(self): v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0]) assert_equal(v.rank("x", pct=True), v_expect) # invalid dim - with raises_regex(ValueError, "not found"): + with pytest.raises(ValueError, match=r"not found"): v.rank("y") def test_big_endian_reduce(self): @@ -1850,7 +1973,7 @@ def assert_assigned_2d(array, key_x, key_y, values): expected = Variable(["x", "y"], [[0, 0], [0, 0], [1, 1]]) assert_identical(expected, v) - with raises_regex(ValueError, "shape mismatch"): + with pytest.raises(ValueError, match=r"shape mismatch"): v[ind, ind] = np.zeros((1, 2, 1)) v = Variable(["x", "y"], [[0, 3, 2], [3, 4, 5]]) @@ -2007,6 +2130,29 @@ def test_getitem_with_mask_nd_indexer(self): self.cls(("x", "y"), [[0, -1], [-1, 2]]), ) + @pytest.mark.parametrize("dim", ["x", "y"]) + @pytest.mark.parametrize("window", [3, 8, 11]) + @pytest.mark.parametrize("center", [True, False]) + def test_dask_rolling(self, dim, window, center): + import dask + import dask.array as da + + dask.config.set(scheduler="single-threaded") + + x = Variable(("x", "y"), np.array(np.random.randn(100, 40), dtype=float)) + dx = Variable(("x", "y"), da.from_array(x, chunks=[(6, 30, 30, 20, 14), 8])) + + expected = x.rolling_window( + dim, window, "window", center=center, fill_value=np.nan + ) + with raise_if_dask_computes(): + actual = dx.rolling_window( + dim, window, "window", center=center, fill_value=np.nan + ) + assert isinstance(actual.data, da.Array) + assert actual.shape == expected.shape + assert_equal(actual, expected) + @requires_sparse class TestVariableWithSparse: @@ -2023,7 +2169,7 @@ class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable) def test_init(self): - with raises_regex(ValueError, "must be 1-dimensional"): + with pytest.raises(ValueError, match=r"must be 1-dimensional"): IndexVariable((), 0) def test_to_index(self): @@ -2038,12 +2184,12 @@ def test_multiindex_default_level_names(self): def test_data(self): x = IndexVariable("x", np.arange(3.0)) - assert isinstance(x._data, PandasIndexAdapter) + assert isinstance(x._data, PandasIndexingAdapter) assert isinstance(x.data, np.ndarray) assert float == x.dtype assert_array_equal(np.arange(3), x) assert float == x.values.dtype - with raises_regex(TypeError, "cannot be modified"): + with pytest.raises(TypeError, match=r"cannot be modified"): x[:] = 0 def test_name(self): @@ -2070,7 +2216,7 @@ def test_get_level_variable(self): level_1 = IndexVariable("x", midx.get_level_values("level_1")) assert_identical(x.get_level_variable("level_1"), level_1) - with raises_regex(ValueError, "has no MultiIndex"): + with pytest.raises(ValueError, match=r"has no MultiIndex"): IndexVariable("y", [10.0]).get_level_variable("level") def test_concat_periods(self): @@ -2118,23 +2264,23 @@ def test_datetime64(self): # These tests make use of multi-dimensional variables, which are not valid # IndexVariable objects: - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_error(self): super().test_getitem_error() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_advanced(self): super().test_getitem_advanced() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_fancy(self): super().test_getitem_fancy() - @pytest.mark.xfail + @pytest.mark.skip def test_getitem_uint(self): super().test_getitem_fancy() - @pytest.mark.xfail + @pytest.mark.skip @pytest.mark.parametrize( "mode", [ @@ -2153,23 +2299,34 @@ def test_getitem_uint(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) - @pytest.mark.xfail - @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.skip def test_pad_constant_values(self, xr_arg, np_arg): super().test_pad_constant_values(xr_arg, np_arg) - @pytest.mark.xfail + @pytest.mark.skip def test_rolling_window(self): super().test_rolling_window() - @pytest.mark.xfail + @pytest.mark.skip + def test_rolling_1d(self): + super().test_rolling_1d() + + @pytest.mark.skip + def test_nd_rolling(self): + super().test_nd_rolling() + + @pytest.mark.skip + def test_rolling_window_errors(self): + super().test_rolling_window_errors() + + @pytest.mark.skip def test_coarsen_2d(self): super().test_coarsen_2d() class TestAsCompatibleData: def test_unchanged_types(self): - types = (np.asarray, PandasIndexAdapter, LazilyOuterIndexedArray) + types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) for t in types: for data in [ np.arange(3), @@ -2242,9 +2399,12 @@ def test_full_like(self): assert_identical(expect, full_like(orig, True, dtype=bool)) # raise error on non-scalar fill_value - with raises_regex(ValueError, "must be scalar"): + with pytest.raises(ValueError, match=r"must be scalar"): full_like(orig, [1.0, 2.0]) + with pytest.raises(ValueError, match="'dtype' cannot be dict-like"): + full_like(orig, True, dtype={"x": bool}) + @requires_dask def test_full_like_dask(self): orig = Variable( @@ -2300,6 +2460,11 @@ def __init__(self, array): class CustomIndexable(CustomArray, indexing.ExplicitlyIndexed): pass + # Type with data stored in values attribute + class CustomWithValuesAttr: + def __init__(self, array): + self.values = array + array = CustomArray(np.arange(3)) orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) assert isinstance(orig._data, np.ndarray) # should not be CustomArray @@ -2308,6 +2473,10 @@ class CustomIndexable(CustomArray, indexing.ExplicitlyIndexed): orig = Variable(dims=("x"), data=array, attrs={"foo": "bar"}) assert isinstance(orig._data, CustomIndexable) + array = CustomWithValuesAttr(np.arange(3)) + orig = Variable(dims=(), data=array) + assert isinstance(orig._data.item(), CustomWithValuesAttr) + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: @@ -2316,7 +2485,7 @@ def test_raise_no_warning_for_nan_in_binary_ops(): class TestBackendIndexing: - """ Make sure all the array wrappers can be indexed. """ + """Make sure all the array wrappers can be indexed.""" @pytest.fixture(autouse=True) def setUp(self): @@ -2335,24 +2504,24 @@ def test_NumpyIndexingAdapter(self): self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # could not doubly wrapping - with raises_regex(TypeError, "NumpyIndexingAdapter only wraps "): + with pytest.raises(TypeError, match=r"NumpyIndexingAdapter only wraps "): v = Variable( dims=("x", "y"), data=NumpyIndexingAdapter(NumpyIndexingAdapter(self.d)) ) - def test_LazilyOuterIndexedArray(self): - v = Variable(dims=("x", "y"), data=LazilyOuterIndexedArray(self.d)) + def test_LazilyIndexedArray(self): + v = Variable(dims=("x", "y"), data=LazilyIndexedArray(self.d)) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping v = Variable( dims=("x", "y"), - data=LazilyOuterIndexedArray(LazilyOuterIndexedArray(self.d)), + data=LazilyIndexedArray(LazilyIndexedArray(self.d)), ) self.check_orthogonal_indexing(v) # hierarchical wrapping v = Variable( - dims=("x", "y"), data=LazilyOuterIndexedArray(NumpyIndexingAdapter(self.d)) + dims=("x", "y"), data=LazilyIndexedArray(NumpyIndexingAdapter(self.d)) ) self.check_orthogonal_indexing(v) @@ -2361,9 +2530,7 @@ def test_CopyOnWriteArray(self): self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) # doubly wrapping - v = Variable( - dims=("x", "y"), data=CopyOnWriteArray(LazilyOuterIndexedArray(self.d)) - ) + v = Variable(dims=("x", "y"), data=CopyOnWriteArray(LazilyIndexedArray(self.d))) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) @@ -2388,3 +2555,92 @@ def test_DaskIndexingAdapter(self): v = Variable(dims=("x", "y"), data=CopyOnWriteArray(DaskIndexingAdapter(da))) self.check_orthogonal_indexing(v) self.check_vectorized_indexing(v) + + +def test_clip(var): + # Copied from test_dataarray (would there be a way to combine the tests?) + result = var.clip(min=0.5) + assert result.min(...) >= 0.5 + + result = var.clip(max=0.5) + assert result.max(...) <= 0.5 + + result = var.clip(min=0.25, max=0.75) + assert result.min(...) >= 0.25 + assert result.max(...) <= 0.75 + + result = var.clip(min=var.mean("x"), max=var.mean("z")) + assert result.dims == var.dims + assert_array_equal( + result.data, + np.clip( + var.data, + var.mean("x").data[np.newaxis, :, :], + var.mean("z").data[:, :, np.newaxis], + ), + ) + + +@pytest.mark.parametrize("Var", [Variable, IndexVariable]) +class TestNumpyCoercion: + def test_from_numpy(self, Var): + v = Var("x", [1, 2, 3]) + + assert_identical(v.as_numpy(), v) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_dask + def test_from_dask(self, Var): + v = Var("x", [1, 2, 3]) + v_chunked = v.chunk(1) + + assert_identical(v_chunked.as_numpy(), v.compute()) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_pint_0_15 + def test_from_pint(self, Var): + from pint import Quantity + + arr = np.array([1, 2, 3]) + v = Var("x", Quantity(arr, units="m")) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_sparse + def test_from_sparse(self, Var): + if Var is IndexVariable: + pytest.skip("Can't have 2D IndexVariables") + + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO(coords=[[0, 1, 2], [0, 1, 2]], data=[1, 2, 3]) + v = Variable(["x", "y"], sparr) + + assert_identical(v.as_numpy(), Variable(["x", "y"], arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self, Var): + import cupy as cp + + arr = np.array([1, 2, 3]) + v = Var("x", cp.array(arr)) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self, Var): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(np.array([1, 2, 3])) + v = Var("x", Quantity(d, units="m")) + + result = v.as_numpy() + assert_identical(result, Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index dc79d417b9c..45e662f118e 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -3,7 +3,7 @@ import xarray as xr from xarray import DataArray -from xarray.tests import assert_allclose, assert_equal, raises_regex +from xarray.tests import assert_allclose, assert_equal from . import raise_if_dask_computes, requires_cftime, requires_dask @@ -15,7 +15,7 @@ def test_weighted_non_DataArray_weights(as_dataset): if as_dataset: data = data.to_dataset(name="data") - with raises_regex(ValueError, "`weights` must be a DataArray"): + with pytest.raises(ValueError, match=r"`weights` must be a DataArray"): data.weighted([1, 2]) @@ -368,3 +368,19 @@ def test_weighted_operations_keep_attr_da_in_ds(operation): result = getattr(data.weighted(weights), operation)(keep_attrs=True) assert data.a.attrs == result.a.attrs + + +@pytest.mark.parametrize("as_dataset", (True, False)) +def test_weighted_bad_dim(as_dataset): + + data = DataArray(np.random.randn(2, 2)) + weights = xr.ones_like(data) + if as_dataset: + data = data.to_dataset(name="data") + + error_msg = ( + f"{data.__class__.__name__}Weighted" + " does not contain the dimensions: {'bad_dim'}" + ) + with pytest.raises(ValueError, match=error_msg): + data.weighted(weights).mean("bad_dim") diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 055be36d80b..78471be7a0e 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -5,33 +5,79 @@ * building tutorials in the documentation. """ -import hashlib -import os as _os -from urllib.request import urlretrieve +import os +import pathlib import numpy as np from .backends.api import open_dataset as _open_dataset +from .backends.rasterio_ import open_rasterio as _open_rasterio from .core.dataarray import DataArray from .core.dataset import Dataset -_default_cache_dir = _os.sep.join(("~", ".xarray_tutorial_data")) - - -def file_md5_checksum(fname): - hash_md5 = hashlib.md5() - with open(fname, "rb") as f: - hash_md5.update(f.read()) - return hash_md5.hexdigest() +_default_cache_dir_name = "xarray_tutorial_data" +base_url = "https://github.com/pydata/xarray-data" +version = "master" + + +def _construct_cache_dir(path): + import pooch + + if isinstance(path, pathlib.Path): + path = os.fspath(path) + elif path is None: + path = pooch.os_cache(_default_cache_dir_name) + + return path + + +external_urls = {} # type: dict +external_rasterio_urls = { + "RGB.byte": "https://github.com/mapbox/rasterio/raw/1.2.1/tests/data/RGB.byte.tif", + "shade": "https://github.com/mapbox/rasterio/raw/1.2.1/tests/data/shade.tif", +} +file_formats = { + "air_temperature": 3, + "rasm": 3, + "ROMS_example": 4, + "tiny": 3, + "eraint_uvz": 3, +} + + +def _check_netcdf_engine_installed(name): + version = file_formats.get(name) + if version == 3: + try: + import scipy # noqa + except ImportError: + try: + import netCDF4 # noqa + except ImportError: + raise ImportError( + f"opening tutorial dataset {name} requires either scipy or " + "netCDF4 to be installed." + ) + if version == 4: + try: + import h5netcdf # noqa + except ImportError: + try: + import netCDF4 # noqa + except ImportError: + raise ImportError( + f"opening tutorial dataset {name} requires either h5netcdf " + "or netCDF4 to be installed." + ) # idea borrowed from Seaborn def open_dataset( name, cache=True, - cache_dir=_default_cache_dir, - github_url="https://github.com/pydata/xarray-data", - branch="master", + cache_dir=None, + *, + engine=None, **kws, ): """ @@ -39,68 +85,134 @@ def open_dataset( If a local copy is found then always use that to avoid network traffic. + Available datasets: + + * ``"air_temperature"``: NCEP reanalysis subset + * ``"rasm"``: Output of the Regional Arctic System Model (RASM) + * ``"ROMS_example"``: Regional Ocean Model System (ROMS) output + * ``"tiny"``: small synthetic dataset with a 1D data variable + * ``"era5-2mt-2019-03-uk.grib"``: ERA5 temperature data over the UK + * ``"eraint_uvz"``: data from ERA-Interim reanalysis, monthly averages of upper level data + Parameters ---------- name : str - Name of the file containing the dataset. If no suffix is given, assumed - to be netCDF ('.nc' is appended) + Name of the file containing the dataset. e.g. 'air_temperature' - cache_dir : str, optional + cache_dir : path-like, optional The directory in which to search for and write cached data. cache : bool, optional If True, then cache data locally for use on subsequent calls - github_url : str - Github repository where the data is stored - branch : str - The git branch to download from - kws : dict, optional + **kws : dict, optional Passed to xarray.open_dataset See Also -------- xarray.open_dataset - """ - root, ext = _os.path.splitext(name) - if not ext: - ext = ".nc" - fullname = root + ext - longdir = _os.path.expanduser(cache_dir) - localfile = _os.sep.join((longdir, fullname)) - md5name = fullname + ".md5" - md5file = _os.sep.join((longdir, md5name)) - - if not _os.path.exists(localfile): - - # This will always leave this directory on disk. - # May want to add an option to remove it. - if not _os.path.isdir(longdir): - _os.mkdir(longdir) - - url = "/".join((github_url, "raw", branch, fullname)) - urlretrieve(url, localfile) - url = "/".join((github_url, "raw", branch, md5name)) - urlretrieve(url, md5file) - - localmd5 = file_md5_checksum(localfile) - with open(md5file) as f: - remotemd5 = f.read() - if localmd5 != remotemd5: - _os.remove(localfile) - msg = """ - MD5 checksum does not match, try downloading dataset again. - """ - raise OSError(msg) - - ds = _open_dataset(localfile, **kws) - + try: + import pooch + except ImportError as e: + raise ImportError( + "tutorial.open_dataset depends on pooch to download and manage datasets." + " To proceed please install pooch." + ) from e + + logger = pooch.get_logger() + logger.setLevel("WARNING") + + cache_dir = _construct_cache_dir(cache_dir) + if name in external_urls: + url = external_urls[name] + else: + path = pathlib.Path(name) + if not path.suffix: + # process the name + default_extension = ".nc" + if engine is None: + _check_netcdf_engine_installed(name) + path = path.with_suffix(default_extension) + elif path.suffix == ".grib": + if engine is None: + engine = "cfgrib" + + url = f"{base_url}/raw/{version}/{path.name}" + + # retrieve the file + filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir) + ds = _open_dataset(filepath, engine=engine, **kws) if not cache: ds = ds.load() - _os.remove(localfile) + pathlib.Path(filepath).unlink() return ds +def open_rasterio( + name, + engine=None, + cache=True, + cache_dir=None, + **kws, +): + """ + Open a rasterio dataset from the online repository (requires internet). + + If a local copy is found then always use that to avoid network traffic. + + Available datasets: + + * ``"RGB.byte"``: TIFF file derived from USGS Landsat 7 ETM imagery. + * ``"shade"``: TIFF file derived from from USGS SRTM 90 data + + ``RGB.byte`` and ``shade`` are downloaded from the ``rasterio`` repository [1]_. + + Parameters + ---------- + name : str + Name of the file containing the dataset. + e.g. 'RGB.byte' + cache_dir : path-like, optional + The directory in which to search for and write cached data. + cache : bool, optional + If True, then cache data locally for use on subsequent calls + **kws : dict, optional + Passed to xarray.open_rasterio + + See Also + -------- + xarray.open_rasterio + + References + ---------- + .. [1] https://github.com/mapbox/rasterio + """ + try: + import pooch + except ImportError as e: + raise ImportError( + "tutorial.open_rasterio depends on pooch to download and manage datasets." + " To proceed please install pooch." + ) from e + + logger = pooch.get_logger() + logger.setLevel("WARNING") + + cache_dir = _construct_cache_dir(cache_dir) + url = external_rasterio_urls.get(name) + if url is None: + raise ValueError(f"unknown rasterio dataset: {name}") + + # retrieve the file + filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir) + arr = _open_rasterio(filepath, **kws) + if not cache: + arr = arr.load() + pathlib.Path(filepath).unlink() + + return arr + + def load_dataset(*args, **kwargs): """ Open, load into memory, and close a dataset from the online repository diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index 8ab2b7cfe31..bf80dcf68cd 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -27,6 +27,7 @@ _xarray_types = (_Variable, _DataArray, _Dataset, _GroupBy) _dispatch_order = (_np.ndarray, _dask_array_type) + _xarray_types +_UNDEFINED = object() def _dispatch_priority(obj): @@ -53,28 +54,30 @@ def __call__(self, *args, **kwargs): ) new_args = args - f = _dask_or_eager_func(self._name, array_args=slice(len(args))) + res = _UNDEFINED if len(args) > 2 or len(args) == 0: raise TypeError( "cannot handle {} arguments for {!r}".format(len(args), self._name) ) elif len(args) == 1: if isinstance(args[0], _xarray_types): - f = args[0]._unary_op(self) + res = args[0]._unary_op(self) else: # len(args) = 2 p1, p2 = map(_dispatch_priority, args) if p1 >= p2: if isinstance(args[0], _xarray_types): - f = args[0]._binary_op(self) + res = args[0]._binary_op(args[1], self) else: if isinstance(args[1], _xarray_types): - f = args[1]._binary_op(self, reflexive=True) + res = args[1]._binary_op(args[0], self, reflexive=True) new_args = tuple(reversed(args)) - res = f(*new_args, **kwargs) + + if res is _UNDEFINED: + f = _dask_or_eager_func(self._name, array_args=slice(len(args))) + res = f(*new_args, **kwargs) if res is NotImplemented: raise TypeError( - "%r not implemented for types (%r, %r)" - % (self._name, type(args[0]), type(args[1])) + f"{self._name!r} not implemented for types ({type(args[0])!r}, {type(args[1])!r})" ) return res @@ -123,11 +126,11 @@ def _create_op(name): doc = _remove_unused_reference_labels(_skip_signature(_dedent(doc), name)) func.__doc__ = ( - "xarray specific variant of numpy.%s. Handles " + f"xarray specific variant of numpy.{name}. Handles " "xarray.Dataset, xarray.DataArray, xarray.Variable, " "numpy.ndarray and dask.array.Array objects with " "automatic dispatching.\n\n" - "Documentation from numpy:\n\n%s" % (name, doc) + f"Documentation from numpy:\n\n{doc}" ) return func diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py new file mode 100644 index 00000000000..b6b7f8cbac7 --- /dev/null +++ b/xarray/util/generate_ops.py @@ -0,0 +1,259 @@ +"""Generate module and stub file for arithmetic operators of various xarray classes. + +For internal xarray development use only. + +Usage: + python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py + python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + +""" +# Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some +# background to some of the design choices made here. + +import sys + +BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) +BINOPS_CMP = ( + ("__lt__", "operator.lt"), + ("__le__", "operator.le"), + ("__gt__", "operator.gt"), + ("__ge__", "operator.ge"), +) +BINOPS_NUM = ( + ("__add__", "operator.add"), + ("__sub__", "operator.sub"), + ("__mul__", "operator.mul"), + ("__pow__", "operator.pow"), + ("__truediv__", "operator.truediv"), + ("__floordiv__", "operator.floordiv"), + ("__mod__", "operator.mod"), + ("__and__", "operator.and_"), + ("__xor__", "operator.xor"), + ("__or__", "operator.or_"), +) +BINOPS_REFLEXIVE = ( + ("__radd__", "operator.add"), + ("__rsub__", "operator.sub"), + ("__rmul__", "operator.mul"), + ("__rpow__", "operator.pow"), + ("__rtruediv__", "operator.truediv"), + ("__rfloordiv__", "operator.floordiv"), + ("__rmod__", "operator.mod"), + ("__rand__", "operator.and_"), + ("__rxor__", "operator.xor"), + ("__ror__", "operator.or_"), +) +BINOPS_INPLACE = ( + ("__iadd__", "operator.iadd"), + ("__isub__", "operator.isub"), + ("__imul__", "operator.imul"), + ("__ipow__", "operator.ipow"), + ("__itruediv__", "operator.itruediv"), + ("__ifloordiv__", "operator.ifloordiv"), + ("__imod__", "operator.imod"), + ("__iand__", "operator.iand"), + ("__ixor__", "operator.ixor"), + ("__ior__", "operator.ior"), +) +UNARY_OPS = ( + ("__neg__", "operator.neg"), + ("__pos__", "operator.pos"), + ("__abs__", "operator.abs"), + ("__invert__", "operator.invert"), +) +# round method and numpy/pandas unary methods which don't modify the data shape, +# so the result should still be wrapped in an Variable/DataArray/Dataset +OTHER_UNARY_METHODS = ( + ("round", "ops.round_"), + ("argsort", "ops.argsort"), + ("conj", "ops.conj"), + ("conjugate", "ops.conjugate"), +) + +template_binop = """ + def {method}(self, other): + return self._binary_op(other, {func})""" +template_reflexive = """ + def {method}(self, other): + return self._binary_op(other, {func}, reflexive=True)""" +template_inplace = """ + def {method}(self, other): + return self._inplace_binary_op(other, {func})""" +template_unary = """ + def {method}(self): + return self._unary_op({func})""" +template_other_unary = """ + def {method}(self, *args, **kwargs): + return self._unary_op({func}, *args, **kwargs)""" +required_method_unary = """ + def _unary_op(self, f, *args, **kwargs): + raise NotImplementedError""" +required_method_binary = """ + def _binary_op(self, other, f, reflexive=False): + raise NotImplementedError""" +required_method_inplace = """ + def _inplace_binary_op(self, other, f): + raise NotImplementedError""" + +# For some methods we override return type `bool` defined by base class `object`. +OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} +NO_OVERRIDE = {"override": ""} + +# Note: in some of the overloads below the return value in reality is NotImplemented, +# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] +# or type(NotImplemented) are not allowed and NoReturn has a different meaning. +# In such cases we are lending the type checkers a hand by specifying the return type +# of the corresponding reflexive method on `other` which will be called instead. +stub_ds = """\ + def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" +stub_da = """\ + @overload{override} + def {method}(self, other: T_Dataset) -> T_Dataset: ... + @overload + def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... # type: ignore[misc] + @overload + def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" +stub_var = """\ + @overload{override} + def {method}(self, other: T_Dataset) -> T_Dataset: ... + @overload + def {method}(self, other: T_DataArray) -> T_DataArray: ... + @overload + def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ...""" +stub_dsgb = """\ + @overload{override} + def {method}(self, other: T_Dataset) -> T_Dataset: ... + @overload + def {method}(self, other: "DataArray") -> "Dataset": ... # type: ignore[misc] + @overload + def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" +stub_dagb = """\ + @overload{override} + def {method}(self, other: T_Dataset) -> T_Dataset: ... + @overload + def {method}(self, other: T_DataArray) -> T_DataArray: ... + @overload + def {method}(self, other: GroupByIncompatible) -> NoReturn: ...""" +stub_unary = """\ + def {method}(self: {self_type}) -> {self_type}: ...""" +stub_other_unary = """\ + def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ...""" +stub_required_unary = """\ + def _unary_op(self, f, *args, **kwargs): ...""" +stub_required_binary = """\ + def _binary_op(self, other, f, reflexive=...): ...""" +stub_required_inplace = """\ + def _inplace_binary_op(self, other, f): ...""" + + +def unops(self_type): + extra_context = {"self_type": self_type} + return [ + ([(None, None)], required_method_unary, stub_required_unary, {}), + (UNARY_OPS, template_unary, stub_unary, extra_context), + (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context), + ] + + +def binops(stub=""): + return [ + ([(None, None)], required_method_binary, stub_required_binary, {}), + (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE), + (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED), + (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE), + ] + + +def inplace(): + return [ + ([(None, None)], required_method_inplace, stub_required_inplace, {}), + (BINOPS_INPLACE, template_inplace, "", {}), + ] + + +ops_info = {} +ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") +ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") +ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") +ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) +ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) + +MODULE_PREAMBLE = '''\ +"""Mixin classes with arithmetic operators.""" +# This file was generated using xarray.util.generate_ops. Do not edit manually. + +import operator + +from . import nputils, ops''' + +STUBFILE_PREAMBLE = '''\ +"""Stub file for mixin classes with arithmetic operators.""" +# This file was generated using xarray.util.generate_ops. Do not edit manually. + +from typing import NoReturn, TypeVar, Union, overload + +import numpy as np + +from .dataarray import DataArray +from .dataset import Dataset +from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy +from .npcompat import ArrayLike +from .variable import Variable + +try: + from dask.array import Array as DaskArray +except ImportError: + DaskArray = np.ndarray + +# DatasetOpsMixin etc. are parent classes of Dataset etc. +T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") +T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") +T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") + +ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] +DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] +DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] +VarCompatible = Union[Variable, ScalarOrArray] +GroupByIncompatible = Union[Variable, GroupBy]''' + +CLASS_PREAMBLE = """{newline} +class {cls_name}: + __slots__ = ()""" + +COPY_DOCSTRING = """\ + {method}.__doc__ = {func}.__doc__""" + + +def render(ops_info, is_module): + """Render the module or stub file.""" + yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + + for cls_name, method_blocks in ops_info.items(): + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) + yield from _render_classbody(method_blocks, is_module) + + +def _render_classbody(method_blocks, is_module): + for method_func_pairs, method_template, stub_template, extra in method_blocks: + template = method_template if is_module else stub_template + if template: + for method, func in method_func_pairs: + yield template.format(method=method, func=func, **extra) + + if is_module: + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) + + +if __name__ == "__main__": + + option = sys.argv[1].lower() if len(sys.argv) == 2 else None + if option not in {"--module", "--stubs"}: + raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") + is_module = option == "--module" + + for line in render(ops_info, is_module): + print(line) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index d643d768093..cd5d425efe2 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -42,15 +42,15 @@ def get_sys_info(): [ ("python", sys.version), ("python-bits", struct.calcsize("P") * 8), - ("OS", "%s" % (sysname)), - ("OS-release", "%s" % (release)), - # ("Version", "%s" % (version)), - ("machine", "%s" % (machine)), - ("processor", "%s" % (processor)), - ("byteorder", "%s" % sys.byteorder), - ("LC_ALL", "%s" % os.environ.get("LC_ALL", "None")), - ("LANG", "%s" % os.environ.get("LANG", "None")), - ("LOCALE", "%s.%s" % locale.getlocale()), + ("OS", f"{sysname}"), + ("OS-release", f"{release}"), + # ("Version", f"{version}"), + ("machine", f"{machine}"), + ("processor", f"{processor}"), + ("byteorder", f"{sys.byteorder}"), + ("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'), + ("LANG", f'{os.environ.get("LANG", "None")}'), + ("LOCALE", f"{locale.getlocale()}"), ] ) except Exception: