diff --git a/.github/workflows/docs_build_and_deploy.yml b/.github/workflows/docs_build_and_deploy.yml index fe65f42f..f9e3a5c8 100644 --- a/.github/workflows/docs_build_and_deploy.yml +++ b/.github/workflows/docs_build_and_deploy.yml @@ -30,6 +30,7 @@ jobs: - uses: neuroinformatics-unit/actions/build_sphinx_docs@main with: python-version: 3.11 + use-make: true deploy_sphinx_docs: name: Deploy Sphinx Docs @@ -42,3 +43,4 @@ jobs: - uses: neuroinformatics-unit/actions/deploy_sphinx_docs@main with: secret_input: ${{ secrets.GITHUB_TOKEN }} + use-make: true diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5ce4de5a..4b7b4c11 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -28,16 +28,16 @@ jobs: strategy: matrix: # Run all supported Python versions on linux - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] os: [ubuntu-latest] # Include 1 Intel macos (13) and 1 M1 macos (latest) and 1 Windows run include: - os: macos-13 - python-version: "3.10" + python-version: "3.11" - os: macos-latest - python-version: "3.10" + python-version: "3.11" - os: windows-latest - python-version: "3.10" + python-version: "3.11" steps: - name: Cache Test Data @@ -52,7 +52,7 @@ jobs: with: python-version: ${{ matrix.python-version }} auto-update-conda: true - channels: conda-forge + channels: conda-forge,nodefaults activate-environment: movement-env - uses: neuroinformatics-unit/actions/test@v2 with: diff --git a/.gitignore b/.gitignore index 19c88937..6611c73a 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,8 @@ instance/ docs/build/ docs/source/examples/ docs/source/api/ +docs/source/api_index.rst +docs/source/snippets/admonitions.md sg_execution_times.rst # MkDocs documentation diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3266baa3..6aafcddd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: autoupdate_schedule: monthly repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-docstring-first @@ -29,12 +29,12 @@ repos: - id: rst-directive-colons - id: rst-inline-touching-normal - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.3 + rev: v0.7.2 hooks: - id: ruff - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: @@ -45,14 +45,14 @@ repos: - types-PyYAML - types-requests - repo: https://github.com/mgedmin/check-manifest - rev: "0.49" + rev: "0.50" hooks: - id: check-manifest args: [--no-build-isolation] additional_dependencies: [setuptools-scm] - repo: https://github.com/codespell-project/codespell # Configuration for codespell is in pyproject.toml - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: diff --git a/CITATION.CFF b/CITATION.CFF new file mode 100644 index 00000000..cc243088 --- /dev/null +++ b/CITATION.CFF @@ -0,0 +1,43 @@ +cff-version: 1.2.0 +title: movement +message: >- + If you use movement in your work, please cite the following Zenodo DOI. +type: software +authors: + - given-names: Nikoloz + family-names: Sirmpilatze + orcid: 'https://orcid.org/0000-0003-1778-2427' + email: niko.sirbiladze@gmail.com + - given-names: Chang Huan + family-names: Lo + - given-names: Sofía + family-names: Miñano + - given-names: Brandon D. 
+ family-names: Peri + - given-names: Dhruv + family-names: Sharma + - given-names: Laura + family-names: Porta + - given-names: Iván + family-names: Varela + - given-names: Adam L. + family-names: Tyson + email: code@adamltyson.com +identifiers: + - type: doi + value: 10.5281/zenodo.12755724 + description: 'A collection of archived snapshots of movement on Zenodo.' +repository-code: 'https://github.com/neuroinformatics-unit/movement' +url: 'https://movement.neuroinformatics.dev/' +abstract: >- + Python tools for analysing body movements across space and time. +keywords: + - behavior + - behaviour + - kinematics + - neuroscience + - animal + - motion + - tracking + - pose +license: BSD-3-Clause diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 384638c4..88f5ecd0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,15 +12,12 @@ development environment for movement. In the following we assume you have First, create and activate a `conda` environment with some prerequisites: ```sh -conda create -n movement-dev -c conda-forge python=3.10 pytables +conda create -n movement-dev -c conda-forge python=3.11 pytables conda activate movement-dev ``` -The above method ensures that you will get packages that often can't be -installed via `pip`, including [hdf5](https://www.hdfgroup.org/solutions/hdf5/). - -To install movement for development, clone the GitHub repository, -and then run from inside the repository: +To install movement for development, clone the [GitHub repository](movement-github:), +and then run from within the repository: ```sh pip install -e .[dev] # works on most shells @@ -162,24 +159,29 @@ The version number is automatically determined from the latest tag on the _main_ The documentation is hosted via [GitHub pages](https://pages.github.com/) at [movement.neuroinformatics.dev](target-movement). Its source files are located in the `docs` folder of this repository. -They are written in either [reStructuredText](https://docutils.sourceforge.io/rst.html) or -[markdown](myst-parser:syntax/typography.html). +They are written in either [Markdown](myst-parser:syntax/typography.html) +or [reStructuredText](https://docutils.sourceforge.io/rst.html). The `index.md` file corresponds to the homepage of the documentation website. -Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive. +Other `.md` or `.rst` files are linked to the homepage via the `toctree` directive. -We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the -[PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) +We use [Sphinx](sphinx-doc:) and the [PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) to build the source files into HTML output. This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`). The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes. The deployment job is only triggered whenever a tag is pushed to the _main_ branch, ensuring that the documentation is published in sync with each PyPI release. + ### Editing the documentation -To edit the documentation, first clone the repository, and install movement in a +To edit the documentation, first clone the repository, and install `movement` in a [development environment](#creating-a-development-environment). +Then, install a few additional dependencies in your development environment to be able to build the documentation locally. 
To do this, run the following command from the root of the repository: +```sh +pip install -r ./docs/requirements.txt +``` + Now create a new branch, edit the documentation source files (`.md` or `.rst` in the `docs` folder), and commit your changes. Submit your documentation changes via a pull request, following the [same guidelines as for code changes](#pull-requests). @@ -199,34 +201,22 @@ existing_file my_new_file ``` -#### Adding external links -If you are adding references to an external link (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. `https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process: +#### Linking to external URLs +If you are adding references to an external URL (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. `https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process: ```markdown [link text](movement-github:issues/1) ``` -If it is not yet defined and you have multiple external links pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`. - +If it is not yet defined and you have multiple external URLs pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`. ### Updating the API reference -If your PR introduces new public-facing functions, classes, or methods, -make sure to add them to the `docs/source/api_index.rst` page, so that they are -included in the [API reference](target-api), -e.g.: - -```rst -My new module --------------- -.. currentmodule:: movement.new_module -.. autosummary:: - :toctree: api - - new_function - NewClass -``` +The [API reference](target-api) is auto-generated by the `docs/make_api_index.py` script, and the [sphinx-autodoc](sphinx-doc:extensions/autodoc.html) and [sphinx-autosummary](sphinx-doc:extensions/autosummary.html) extensions. +The script generates the `docs/source/api_index.rst` file containing the list of modules to be included in the [API reference](target-api). +The plugins then generate the API reference pages for each module listed in `api_index.rst`, based on the docstrings in the source code. +So make sure that all your public functions/classes/methods have valid docstrings following the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. +Our `pre-commit` hooks include some checks (`ruff` rules) that ensure the docstrings are formatted consistently. -For this to work, your functions/classes/methods will need to have docstrings -that follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. 
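+
+As a rough illustration, a numpydoc-style docstring looks something like the sketch below (the function itself is hypothetical and not part of `movement`):
+
+```python
+def compute_example_metric(data, scale=1.0):
+    """Compute an example metric from the input data.
+
+    Parameters
+    ----------
+    data : xarray.DataArray
+        The input data to compute the metric from.
+    scale : float, optional
+        Scaling factor applied to the result. Default is 1.0.
+
+    Returns
+    -------
+    xarray.DataArray
+        The computed metric, scaled by ``scale``.
+
+    """
+    return data * scale
+```
+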
+If your PR introduces new modules that should *not* be documented in the [API reference](target-api), or if there are changes to existing modules that necessitate their removal from the documentation, make sure to update the `exclude_modules` list within the `docs/make_api_index.py` script to reflect these exclusions. ### Updating the examples We use [sphinx-gallery](sphinx-gallery:) @@ -235,7 +225,7 @@ To add new examples, you will need to create a new `.py` file in `examples/`. The file should be structured as specified in the relevant [sphinx-gallery documentation](sphinx-gallery:syntax). -We are using sphinx-gallery's [integration with binder](https://sphinx-gallery.github.io/stable/configuration.html#binder-links) +We are using sphinx-gallery's [integration with binder](sphinx-gallery:configuration#binder-links) to provide interactive versions of the examples. If your examples rely on packages that are not among movement's dependencies, you will need to add them to the `docs/source/environment.yml` file. @@ -243,29 +233,85 @@ That file is used by binder to create the conda environment in which the examples are run. See the relevant section of the [binder documentation](https://mybinder.readthedocs.io/en/latest/using/config_files.html). +### Cross-referencing Python objects +:::{note} +Docstrings in the `.py` files for the [API reference](target-api) and the [examples](target-examples) are converted into `.rst` files, so these should use reStructuredText syntax. +::: + +#### Internal references +::::{tab-set} +:::{tab-item} Markdown +For referencing movement objects in `.md` files, use the `` {role}`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects). + +For example, to reference the {mod}`movement.io.load_poses` module, use: +```markdown +{mod}`movement.io.load_poses` +``` +::: +:::{tab-item} RestructuredText +For referencing movement objects in `.rst` files, use the `` :role:`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects). + +For example, to reference the {mod}`movement.io.load_poses` module, use: +```rst +:mod:`movement.io.load_poses` +``` +::: +:::: + +#### External references +For referencing external Python objects using [intersphinx](sphinx-doc:extensions/intersphinx.html), +ensure the mapping between module names and their documentation URLs is defined in [`intersphinx_mapping`](sphinx-doc:extensions/intersphinx.html#confval-intersphinx_mapping) in `docs/source/conf.py`. +Once the module is included in the mapping, use the same syntax as for [internal references](#internal-references). + +::::{tab-set} +:::{tab-item} Markdown +For example, to reference the {meth}`xarray.Dataset.update` method, use: +```markdown +{meth}`xarray.Dataset.update` +``` +::: + +:::{tab-item} RestructuredText +For example, to reference the {meth}`xarray.Dataset.update` method, use: +```rst +:meth:`xarray.Dataset.update` +``` +::: +:::: + + ### Building the documentation locally -We recommend that you build and view the documentation website locally, before you push it. -To do so, first install the requirements for building the documentation: +We recommend that you build and view the documentation website locally, before you push your proposed changes. + +First, ensure your development environment with the required dependencies is active (see [Editing the documentation](#editing-the-documentation) for details on how to create it). 
Then, navigate to the `docs/` directory: ```sh -pip install -r docs/requirements.txt +cd docs ``` +All subsequent commands should be run from this directory. + +To build the documentation, run: -Then, from the root of the repository, run: ```sh -sphinx-build docs/source docs/build +make html ``` +The local build can be viewed by opening `docs/build/html/index.html` in a browser. -You can view the local build by opening `docs/build/index.html` in a browser. -To refresh the documentation, after making changes, remove the `docs/build` folder and re-run the above command: +To re-build the documentation after making changes, +we recommend removing existing build files first. +The following command will remove all generated files in `docs/`, +including the auto-generated files `source/api_index.rst` and +`source/snippets/admonitions.md`, as well as all files in + `build/`, `source/api/`, and `source/examples/`. + It will then re-build the documentation: ```sh -rm -rf docs/build && sphinx-build docs/source docs/build +make clean html ``` To check that external links are correctly resolved, run: ```sh -sphinx-build docs/source docs/build -b linkcheck +make linkcheck ``` If the linkcheck step incorrectly marks links with valid anchors as broken, you can skip checking the anchors in specific links by adding the URLs to `linkcheck_anchors_ignore_for_url` in `docs/source/conf.py`, e.g.: @@ -279,6 +325,14 @@ linkcheck_anchors_ignore_for_url = [ ] ``` +:::{tip} +The `make` commands can be combined to run multiple tasks sequentially. +For example, to re-build the documentation and check the links, run: +```sh +make clean html linkcheck +``` +::: + ## Sample data We maintain some sample datasets to be used for testing, examples and tutorials on an @@ -289,12 +343,12 @@ GIN has a GitHub-like interface and git-like [CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) functionalities. Currently, the data repository contains sample pose estimation data files -stored in the `poses` folder. For some of these files, we also host +stored in the `poses` folder, and tracked bounding boxes data files under the `bboxes` folder. For some of these files, we also host the associated video file (in the `videos` folder) and/or a single video frame (in the `frames`) folder. These can be used to develop and -test visualisations, e.g. overlaying pose data on video frames. +test visualisations, e.g. to overlay the data on video frames. The `metadata.yaml` file holds metadata for each sample dataset, -including information on data provenance as well as the mapping between pose data files and related +including information on data provenance as well as the mapping between data files and related video/frame files. ### Fetching data @@ -307,9 +361,9 @@ The relevant functionality is implemented in the `movement.sample_data.py` modul The most important parts of this module are: 1. The `SAMPLE_DATA` download manager object. -2. The `list_datasets()` function, which returns a list of the available pose datasets (file names of the pose data files). -3. The `fetch_dataset_paths()` function, which returns a dictionary containing local paths to the files associated with a particular sample dataset: `poses`, `frame`, `video`. If the relevant files are not already cached locally, they will be downloaded. -4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose data into `movement`, returning an `xarray.Dataset` object. 
The local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively. +2. The `list_datasets()` function, which returns a list of the available poses and bounding boxes datasets (file names of the data files). +3. The `fetch_dataset_paths()` function, which returns a dictionary containing local paths to the files associated with a particular sample dataset: `poses` or `bboxes`, `frame`, `video`. If the relevant files are not already cached locally, they will be downloaded. +4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose or bounding box data into movement, returning an `xarray.Dataset` object. If available, the local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively. By default, the downloaded files are stored in the `~/.movement/data` folder. This can be changed by setting the `DATA_DIR` variable in the `movement.sample_data.py` module. @@ -322,17 +376,43 @@ To add a new file, you will need to: 2. Ask to be added as a collaborator on the [movement data repository](gin:neuroinformatics/movement-test-data) (if not already) 3. Download the [GIN CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) and set it up with your GIN credentials, by running `gin login` in a terminal. 4. Clone the movement data repository to your local machine, by running `gin get neuroinformatics/movement-test-data` in a terminal. -5. Add your new files to the `poses`, `videos`, and/or `frames` folders as appropriate. Follow the existing file naming conventions as closely as possible. -6. Determine the sha256 checksum hash of each new file by running `sha256sum ` in a terminal. For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with `movement` installed, it will calculate the sha256 hashes for all files in the `poses`, `videos`, and `frames` folders and write them to files named `poses_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt`, respectively. +5. Add your new files to the `poses`, `bboxes`, `videos` and/or `frames` folders as appropriate. Follow the existing file naming conventions as closely as possible. +6. Determine the sha256 checksum hash of each new file. You can do this in a terminal by running: + ::::{tab-set} + :::{tab-item} Ubuntu + ```bash + sha256sum + ``` + ::: + + :::{tab-item} MacOS + ```bash + shasum -a 256 + ``` + ::: + + :::{tab-item} Windows + ```bash + certutil -hashfile SHA256 + ``` + ::: + :::: + For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with movement installed, it will calculate the sha256 hashes for all files in the `poses`, `bboxes`, `videos` and `frames` folders and write them to files named `poses_hashes.txt`, `bboxes_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt` respectively. + 7. Add metadata for your new files to `metadata.yaml`, including their sha256 hashes you've calculated. See the example entry below for guidance. + 8. Commit a specific file with `gin commit -m `, or `gin commit -m .` to commit all changes. + 9. 
Upload the committed changes to the GIN repository by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally. + + ### `metadata.yaml` example entry ```yaml "SLEAP_three-mice_Aeon_proofread.analysis.h5": sha256sum: "82ebd281c406a61536092863bc51d1a5c7c10316275119f7daf01c1ff33eac2a" source_software: "SLEAP" + type: "poses" # "poses" or "bboxes" depending on the type of tracked data fps: 50 species: "mouse" number_of_individuals: 3 diff --git a/MANIFEST.in b/MANIFEST.in index ff091745..0fbedce8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE include *.md +include CITATION.CFF exclude .pre-commit-config.yaml exclude .cruft.json diff --git a/README.md b/README.md index 847ba6ed..ab66a9b8 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ +[![Python Version](https://img.shields.io/pypi/pyversions/movement.svg)](https://pypi.org/project/movement) [![License](https://img.shields.io/badge/License-BSD_3--Clause-orange.svg)](https://opensource.org/licenses/BSD-3-Clause) ![CI](https://img.shields.io/github/actions/workflow/status/neuroinformatics-unit/movement/test_and_deploy.yml?label=CI) [![codecov](https://codecov.io/gh/neuroinformatics-unit/movement/branch/main/graph/badge.svg?token=P8CCH3TI8K)](https://codecov.io/gh/neuroinformatics-unit/movement) [![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) [![project chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://neuroinformatics.zulipchat.com/#narrow/stream/406001-Movement/topic/Welcome!) +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.12755724.svg)](https://zenodo.org/doi/10.5281/zenodo.12755724) # movement @@ -14,17 +16,12 @@ A Python toolbox for analysing body movements across space and time, to aid the ## Quick install -First, create and activate a conda environment with the required dependencies: +Create and activate a conda environment with movement installed: ``` -conda create -n movement-env -c conda-forge python=3.10 pytables +conda create -n movement-env -c conda-forge movement conda activate movement-env ``` -Then install the `movement` package: -``` -pip install movement -``` - > [!Note] > Read the [documentation](https://movement.neuroinformatics.dev) for more information, including [full installation instructions](https://movement.neuroinformatics.dev/getting_started/installation.html) and [examples](https://movement.neuroinformatics.dev/examples/index.html). @@ -37,10 +34,19 @@ We aim to support a range of pose estimation packages, along with 2D or 3D track Find out more on our [mission and scope](https://movement.neuroinformatics.dev/community/mission-scope.html) statement and our [roadmap](https://movement.neuroinformatics.dev/community/roadmaps.html). -## Status + + > [!Warning] > 🏗️ The package is currently in early development and the interface is subject to change. Feel free to play around and provide feedback. +> [!Tip] +> If you prefer analysing your data in R, we recommend checking out the +> [animovement](https://www.roald-arboel.com/animovement/) toolbox, which is similar in scope. +> We are working together with its developer +> to gradually converge on common data standards and workflows. 
+ + + ## Join the movement Contributions to movement are absolutely encouraged, whether to fix a bug, develop a new feature, or improve the documentation. @@ -48,6 +54,12 @@ To help you get started, we have prepared a detailed [contributing guide](https: You are welcome to chat with the team on [zulip](https://neuroinformatics.zulipchat.com/#narrow/stream/406001-Movement). You can also [open an issue](https://github.com/neuroinformatics-unit/movement/issues) to report a bug or request a new feature. +## Citation + +If you use movement in your work, please cite the following Zenodo DOI: + +> Nikoloz Sirmpilatze, Chang Huan Lo, Sofía Miñano, Brandon D. Peri, Dhruv Sharma, Laura Porta, Iván Varela & Adam L. Tyson (2024). neuroinformatics-unit/movement. Zenodo. https://zenodo.org/doi/10.5281/zenodo.12755724 + ## License ⚖️ [BSD 3-Clause](./LICENSE) diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf1..df622291 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,8 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +# -W: if there are warnings, treat them as errors and exit with status 1. +SPHINXOPTS ?= -W SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build @@ -14,7 +15,24 @@ help: .PHONY: help Makefile +# Generate the API index file +api_index.rst: + python make_api_index.py + +# Generate the snippets/admonitions.md file +# by converting the admonitions in the repo's README.md to MyST format +admonitions.md: + python convert_admonitions.py + +# Remove all generated files +clean: + rm -rf ./build + rm -f ./source/api_index.rst + rm -rf ./source/api + rm -rf ./source/examples + rm -rf ./source/snippets/admonitions.md + # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile +%: Makefile api_index.rst admonitions.md @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/convert_admonitions.py b/docs/convert_admonitions.py new file mode 100644 index 00000000..dc3dc27c --- /dev/null +++ b/docs/convert_admonitions.py @@ -0,0 +1,87 @@ +"""Convert admonitions GitHub Flavored Markdown (GFM) to MyST Markdown.""" + +import re +from pathlib import Path + +# Valid admonition types supported by both GFM and MyST (case-insensitive) +VALID_TYPES = {"note", "tip", "important", "warning", "caution"} + + +def convert_gfm_admonitions_to_myst_md( + input_path: Path, output_path: Path, exclude: set[str] | None = None +): + """Convert admonitions from GitHub Flavored Markdown to MyST. + + Extracts GitHub Flavored Markdown admonitions from the input file and + writes them to the output file as MyST Markdown admonitions. + The original admonition type and order are preserved. + + Parameters + ---------- + input_path : Path + Path to the input file containing GitHub Flavored Markdown. + output_path : Path + Path to the output file to write the MyST Markdown admonitions. + exclude : set[str], optional + Set of admonition types to exclude from conversion (case-insensitive). + Default is None. 
+ + """ + excluded_types = {s.lower() for s in (exclude or set())} + + # Read the input file + gfm_text = input_path.read_text(encoding="utf-8") + + # Regex pattern to match GFM admonitions + pattern = r"(^> \[!(\w+)\]\n(?:^> .*\n?)*)" + matches = re.finditer(pattern, gfm_text, re.MULTILINE) + + # Process matches and collect converted admonitions + admonitions = [] + for match in matches: + adm_myst = _process_match(match, excluded_types) + if adm_myst: + admonitions.append(adm_myst) + + if admonitions: + # Write all admonitions to a single file + output_path.write_text("\n".join(admonitions) + "\n", encoding="utf-8") + print(f"Admonitions written to {output_path}") + else: + print("No GitHub Markdown admonitions found.") + + +def _process_match(match: re.Match, excluded_types: set[str]) -> str | None: + """Process a regex match and return the converted admonition if valid.""" + # Extract the admonition type + adm_type = match.group(2).lower() + if adm_type not in VALID_TYPES or adm_type in excluded_types: + return None + + # Extract the content lines + full_block = match.group(0) + content = "\n".join( + line[2:].strip() + for line in full_block.split("\n") + if line.startswith("> ") and not line.startswith("> [!") + ).strip() + + # Return the converted admonition + return ":::{" + adm_type + "}\n" + content + "\n" + ":::\n" + + +if __name__ == "__main__": + # Path to the README.md file + # (1 level above the current script) + docs_dir = Path(__file__).resolve().parent + readme_path = docs_dir.parent / "README.md" + + # Path to the output file + # (inside the docs/source/snippets directory) + snippets_dir = docs_dir / "source" / "snippets" + target_path = snippets_dir / "admonitions.md" + + # Call the function + convert_gfm_admonitions_to_myst_md( + readme_path, target_path, exclude={"note"} + ) diff --git a/docs/make.bat b/docs/make.bat index dc1312ab..1969d4b3 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -9,6 +9,7 @@ if "%SPHINXBUILD%" == "" ( ) set SOURCEDIR=source set BUILDDIR=build +set SPHINXOPTS=-W %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( @@ -25,7 +26,27 @@ if errorlevel 9009 ( if "%1" == "" goto help -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +:process_targets +if "%1" == "clean" ( + echo Removing auto-generated files... + rmdir /S /Q %BUILDDIR% + del /Q %SOURCEDIR%\api_index.rst + rmdir /S /Q %SOURCEDIR%\api\ + rmdir /S /Q %SOURCEDIR%\examples\ + del /Q %SOURCEDIR%\snippets\admonitions.md +) else ( + echo Generating API index... + python make_api_index.py + + echo Converting admonitions... 
+ python convert_admonitions.py + + %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +) + +shift +if not "%1" == "" goto process_targets + goto end :help diff --git a/docs/make_api_index.py b/docs/make_api_index.py new file mode 100644 index 00000000..c850022e --- /dev/null +++ b/docs/make_api_index.py @@ -0,0 +1,37 @@ +"""Generate the API index page for all ``movement`` modules.""" + +import os +from pathlib import Path + +# Modules to exclude from the API index +exclude_modules = ["cli_entrypoint"] + +# Set the current working directory to the directory of this script +script_dir = Path(__file__).resolve().parent +os.chdir(script_dir) + + +def make_api_index(): + """Create a doctree of all ``movement`` modules.""" + doctree = "\n" + api_path = Path("../movement") + for path in sorted(api_path.rglob("*.py")): + if path.name.startswith("_"): + continue + # Convert file path to module name + rel_path = path.relative_to(api_path.parent) + module_name = str(rel_path.with_suffix("")).replace(os.sep, ".") + if rel_path.stem not in exclude_modules: + doctree += f" {module_name}\n" + # Get the header + api_head_path = Path("source") / "_templates" / "api_index_head.rst" + api_head = api_head_path.read_text() + # Write api_index.rst with header + doctree + output_path = Path("source") / "api_index.rst" + with output_path.open("w") as f: + f.write(api_head) + f.write(doctree) + + +if __name__ == "__main__": + make_api_index() diff --git a/docs/requirements.txt b/docs/requirements.txt index 63615d66..0a950f2e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ myst-parser nbsphinx pydata-sphinx-theme setuptools-scm -sphinx>=7.0 +sphinx sphinx-autodoc-typehints sphinx-design sphinx-gallery diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css index 40e09e63..ce6b5ae3 100644 --- a/docs/source/_static/css/custom.css +++ b/docs/source/_static/css/custom.css @@ -30,3 +30,12 @@ display: flex; flex-wrap: wrap; justify-content: space-between; } + +/* Disable decoration for all but movement backrefs */ +a[class^="sphx-glr-backref-module-"], +a[class^="sphx-glr-backref-type-"] { + text-decoration: none; +} +a[class^="sphx-glr-backref-module-movement"] { + text-decoration: underline; +} diff --git a/docs/source/_static/data_icon.png b/docs/source/_static/data_icon.png new file mode 100644 index 00000000..835121ff Binary files /dev/null and b/docs/source/_static/data_icon.png differ diff --git a/docs/source/_static/dataset_structure.png b/docs/source/_static/dataset_structure.png index 4e6a17d9..13506196 100644 Binary files a/docs/source/_static/dataset_structure.png and b/docs/source/_static/dataset_structure.png differ diff --git a/docs/source/_static/movement_overview.png b/docs/source/_static/movement_overview.png index 8af12daa..33c5927a 100644 Binary files a/docs/source/_static/movement_overview.png and b/docs/source/_static/movement_overview.png differ diff --git a/docs/source/_templates/api_index_head.rst b/docs/source/_templates/api_index_head.rst new file mode 100644 index 00000000..0c1e6b7e --- /dev/null +++ b/docs/source/_templates/api_index_head.rst @@ -0,0 +1,16 @@ +.. + This file is auto-generated. + +.. _target-api: + +API Reference +============= + +Information on specific functions, classes, and methods. + +.. rubric:: Modules + +.. 
autosummary:: + :toctree: api + :recursive: + :nosignatures: diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst new file mode 100644 index 00000000..ab085cb7 --- /dev/null +++ b/docs/source/_templates/autosummary/class.rst @@ -0,0 +1,31 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :show-inheritance: + :inherited-members: + + {% block methods %} + {% set ns = namespace(has_public_methods=false) %} + + {% if methods %} + {% for item in methods %} + {% if not item.startswith('_') %} + {% set ns.has_public_methods = true %} + {% endif %} + {%- endfor %} + {% endif %} + + {% if ns.has_public_methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + {% if not item.startswith('_') %} + ~{{ name }}.{{ item }} + {% endif %} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/source/_templates/autosummary/function.rst b/docs/source/_templates/autosummary/function.rst new file mode 100644 index 00000000..5536fa10 --- /dev/null +++ b/docs/source/_templates/autosummary/function.rst @@ -0,0 +1,5 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. auto{{ objtype }}:: {{ objname }} diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst new file mode 100644 index 00000000..306ccd36 --- /dev/null +++ b/docs/source/_templates/autosummary/module.rst @@ -0,0 +1,31 @@ +{{ fullname | escape | underline }} + +.. rubric:: Description + +.. automodule:: {{ fullname }} + +.. currentmodule:: {{ fullname }} + +{% if classes %} +.. rubric:: Classes + +.. autosummary:: + :toctree: . + :nosignatures: + {% for class in classes %} + {{ class.split('.')[-1] }} + {% endfor %} + +{% endif %} + +{% if functions %} +.. rubric:: Functions + +.. autosummary:: + :toctree: . + :nosignatures: + {% for function in functions %} + {{ function.split('.')[-1] }} + {% endfor %} + +{% endif %} diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst deleted file mode 100644 index a6715af6..00000000 --- a/docs/source/api_index.rst +++ /dev/null @@ -1,93 +0,0 @@ -.. _target-api: - -API Reference -============= - - -Input/Output ------------- -.. currentmodule:: movement.io.load_poses -.. autosummary:: - :toctree: api - - from_file - from_sleap_file - from_dlc_file - from_dlc_df - from_lp_file - -.. currentmodule:: movement.io.save_poses -.. autosummary:: - :toctree: api - - to_dlc_file - to_dlc_df - to_sleap_analysis_file - to_lp_file - -.. currentmodule:: movement.io.validators -.. autosummary:: - :toctree: api - - ValidFile - ValidHDF5 - ValidDeepLabCutCSV - ValidPosesDataset - -Sample Data ------------ -.. currentmodule:: movement.sample_data -.. autosummary:: - :toctree: api - - list_datasets - fetch_dataset_paths - fetch_dataset - -Filtering ---------- -.. currentmodule:: movement.filtering -.. autosummary:: - :toctree: api - - filter_by_confidence - median_filter - savgol_filter - interpolate_over_time - report_nan_values - - -Analysis ------------ -.. currentmodule:: movement.analysis.kinematics -.. autosummary:: - :toctree: api - - compute_displacement - compute_velocity - compute_acceleration - -.. currentmodule:: movement.utils.vector -.. autosummary:: - :toctree: api - - cart2pol - pol2cart - -MovementDataset ---------------- -.. currentmodule:: movement.move_accessor -.. autosummary:: - :toctree: api - - MovementDataset - -Logging -------- -.. 
currentmodule:: movement.logging -.. autosummary:: - :toctree: api - - configure_logging - log_error - log_warning diff --git a/docs/source/community/roadmaps.md b/docs/source/community/roadmaps.md index 69b6000e..78b1bd67 100644 --- a/docs/source/community/roadmaps.md +++ b/docs/source/community/roadmaps.md @@ -24,5 +24,5 @@ We plan to release version `v0.1` of movement in early 2024, providing a minimal - [x] Ability to compute velocity and acceleration from pose tracks. - [x] Public website with [documentation](target-movement). - [x] Package released on [PyPI](https://pypi.org/project/movement/). -- [ ] Package released on [conda-forge](https://conda-forge.org/). +- [x] Package released on [conda-forge](https://anaconda.org/conda-forge/movement). - [ ] Ability to visualise pose tracks using [napari](napari:). We aim to represent pose tracks via napari's [Points](napari:howtos/layers/points) and [Tracks](napari:howtos/layers/tracks) layers and overlay them on video frames. diff --git a/docs/source/conf.py b/docs/source/conf.py index d035f9b0..3039162e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,6 +47,7 @@ "sphinx_design", "sphinx_gallery.gen_gallery", "sphinx_sitemap", + "sphinx.ext.autosectionlabel", ] # Configure the myst parser to enable cool markdown features @@ -67,7 +68,7 @@ "tasklist", ] # Automatically add anchors to markdown headings -myst_heading_anchors = 3 +myst_heading_anchors = 4 # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -76,6 +77,9 @@ autosummary_generate = True autodoc_default_flags = ["members", "inherited-members"] +# Prefix section labels with the document name +autosectionlabel_prefix_document = True + # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
@@ -104,7 +108,7 @@ "binderhub_url": "https://mybinder.org", "dependencies": ["environment.yml"], }, - 'remove_config_comments': True, + "remove_config_comments": True, # do not render config params set as # sphinx_gallery_config [= value] } @@ -166,6 +170,11 @@ "https://neuroinformatics.zulipchat.com/", "https://github.com/talmolab/sleap/blob/v1.3.3/sleap/info/write_tracking_h5.py", ] +# A list of regular expressions that match URIs that should not be checked +linkcheck_ignore = [ + "https://pubs.acs.org/doi/*", # Checking dois is forbidden here + "https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error +] myst_url_schemes = { "http": None, @@ -183,7 +192,14 @@ "napari": "https://napari.org/dev/{{path}}", "setuptools-scm": "https://setuptools-scm.readthedocs.io/en/latest/{{path}}#{{fragment}}", "sleap": "https://sleap.ai/{{path}}#{{fragment}}", - "sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}", + "sphinx-doc": "https://www.sphinx-doc.org/en/master/usage/{{path}}#{{fragment}}", + "sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}#{{fragment}}", "xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}", "lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}", + "via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}", +} + +intersphinx_mapping = { + "xarray": ("https://docs.xarray.dev/en/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), } diff --git a/docs/source/environment.yml b/docs/source/environment.yml index 00c7d126..b84ac374 100644 --- a/docs/source/environment.yml +++ b/docs/source/environment.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - - python=3.10 + - python=3.11 - pytables - pip: - movement diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index b98a529c..4795a9cd 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -2,7 +2,7 @@ Start by [installing the package](installation.md). -After that try [loading some of your own predicted poses](target-loading), +After that, try [loading some of your own tracked poses](target-loading-pose-tracks) or [bounding boxes' trajectories](target-loading-bbox-tracks), from one of the [supported formats](target-formats). Alternatively, you can use the [sample data](target-sample-data) provided with the package. diff --git a/docs/source/getting_started/input_output.md b/docs/source/getting_started/input_output.md index b4ad43f9..ccb16a7b 100644 --- a/docs/source/getting_started/input_output.md +++ b/docs/source/getting_started/input_output.md @@ -4,35 +4,40 @@ (target-formats)= ## Supported formats (target-supported-formats)= -`movement` can load pose tracks from various pose estimation frameworks. -Currently, these include: +`movement` supports the analysis of trajectories of keypoints (_pose tracks_) and of bounding boxes' centroids (_bounding boxes' tracks_). + +To analyse pose tracks, `movement` supports loading data from various frameworks: - [DeepLabCut](dlc:) (DLC) - [SLEAP](sleap:) (SLEAP) - [LightingPose](lp:) (LP) -:::{warning} -`movement` only deals with the predicted pose tracks output by these -software packages. It does not support the training or labelling of the data. +To analyse bounding boxes' tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html). 
+ +:::{note} +At the moment `movement` only deals with tracked data: either keypoints or bounding boxes whose identities are known from one frame to the next, for a consecutive set of frames. For the pose estimation case, this means it only deals with the predictions output by the software packages above. It currently does not support loading manually labelled data (since this is most often defined over a non-continuous set of frames). ::: -(target-loading)= +Below we explain how you can load pose tracks and bounding boxes' tracks into `movement`, and how you can export a `movement` poses dataset to different file formats. You can also try `movement` out on some [sample data](target-sample-data) +included with the package. + + +(target-loading-pose-tracks)= ## Loading pose tracks -The loading functionalities are provided by the +The pose tracks loading functionalities are provided by the {mod}`movement.io.load_poses` module, which can be imported as follows: ```python from movement.io import load_poses ``` -Depending on the source sofrware, one of the following functions can be used. +To read a pose tracks file into a [movement poses dataset](target-poses-and-bboxes-dataset), we provide specific functions for each of the supported formats. We additionally provide a more general `from_numpy()` method, with which we can build a [movement poses dataset](target-poses-and-bboxes-dataset) from a set of NumPy arrays. ::::{tab-set} :::{tab-item} SLEAP -Load from [SLEAP analysis files](sleap:tutorials/analysis) (.h5, recommended), -or from .slp files (experimental): +To load [SLEAP analysis files](sleap:tutorials/analysis) in .h5 format (recommended): ```python ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30) @@ -41,9 +46,7 @@ ds = load_poses.from_file( "/path/to/file.analysis.h5", source_software="SLEAP", fps=30 ) ``` - -You can also load from SLEAP .slp files in the same way, but there are caveats -to that approach (see notes in {func}`movement.io.load_poses.from_sleap_file`). +To load [SLEAP analysis files](sleap:tutorials/analysis) in .slp format (experimental, see notes in {func}`movement.io.load_poses.from_sleap_file`): ```python ds = load_poses.from_sleap_file("/path/to/file.predictions.slp", fps=30) @@ -52,7 +55,7 @@ ds = load_poses.from_sleap_file("/path/to/file.predictions.slp", fps=30) :::{tab-item} DeepLabCut -Load from DeepLabCut files (.h5): +To load DeepLabCut files in .h5 format: ```python ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30) @@ -62,8 +65,7 @@ ds = load_poses.from_file( ) ``` -You may also load .csv files -(assuming they are formatted as DeepLabCut expects them): +To load DeepLabCut files in .csv format: ```python ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30) ``` @@ -71,7 +73,7 @@ ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30) :::{tab-item} LightningPose -Load from LightningPose files (.csv): +To load LightningPose files in .csv format: ```python ds = load_poses.from_lp_file("/path/to/file.analysis.csv", fps=30) @@ -82,23 +84,91 @@ ds = load_poses.from_file( ``` ::: +:::{tab-item} From NumPy + +In the example below, we create random position data for two individuals, ``Alice`` and ``Bob``, +with three keypoints each: ``snout``, ``centre``, and ``tail_base``. These keypoints are tracked in 2D space for 100 frames, at 30 fps. The confidence scores are set to 1 for all points. 
+
+```python
+import numpy as np
+
+ds = load_poses.from_numpy(
+    position_array=np.random.rand(100, 2, 3, 2),
+    confidence_array=np.ones((100, 2, 3)),
+    individual_names=["Alice", "Bob"],
+    keypoint_names=["snout", "centre", "tail_base"],
+    fps=30,
+)
+```
+:::
+
 ::::
 
-The loaded data include the predicted positions for each individual and
-keypoint as well as the associated point-wise confidence values, as reported by
-the pose estimation software. See the [movement dataset](target-dataset) page
-for more information on data structure.
+The resulting poses data structure `ds` will include the predicted trajectories for each individual and
+keypoint, as well as the associated point-wise confidence values reported by
+the pose estimation software.
 
-You can also try `movement` out on some [sample data](target-sample-data)
-included with the package.
+For more information on the poses data structure, see the [movement poses dataset](target-poses-and-bboxes-dataset) page.
+
+
+(target-loading-bbox-tracks)=
+## Loading bounding boxes' tracks
+To load bounding boxes' tracks into a [movement bounding boxes dataset](target-poses-and-bboxes-dataset), we need the functions from the
+{mod}`movement.io.load_bboxes` module. This module can be imported as:
+
+```python
+from movement.io import load_bboxes
+```
+
+We currently support loading bounding boxes' tracks in the VGG Image Annotator (VIA) format only. However, like in the poses datasets, we additionally provide a `from_numpy()` method, with which we can build a [movement bounding boxes dataset](target-poses-and-bboxes-dataset) from a set of NumPy arrays.
+
+::::{tab-set}
+:::{tab-item} VGG Image Annotator
+
+To load a VIA tracks .csv file:
+```python
+ds = load_bboxes.from_via_tracks_file("path/to/file.csv", fps=30)
+
+# or equivalently
+ds = load_bboxes.from_file(
+    "path/to/file.csv",
+    source_software="VIA-tracks",
+    fps=30,
+)
+```
+:::
+
+:::{tab-item} From NumPy
+
+In the example below, we create random position data for two bounding boxes, ``id_0`` and ``id_1``,
+both with the same width (40 pixels) and height (30 pixels). These are tracked in 2D space for 100 frames, which will be numbered in the resulting dataset from 0 to 99. The confidence score for all bounding boxes is set to 0.5.
+
+```python
+import numpy as np
+
+ds = load_bboxes.from_numpy(
+    position_array=np.random.rand(100, 2, 2),
+    shape_array=np.ones((100, 2, 2)) * [40, 30],
+    confidence_array=np.ones((100, 2)) * 0.5,
+    individual_names=["id_0", "id_1"],
+)
+```
+:::
+
+::::
+
+The resulting data structure `ds` will include the centroid trajectories for each tracked bounding box, the boxes' widths and heights, and their associated confidence values if provided.
 
-(target-saving)=
+For more information on the bounding boxes data structure, see the [movement bounding boxes dataset](target-poses-and-bboxes-dataset) page.
+
+
+(target-saving-pose-tracks)=
 ## Saving pose tracks
-[movement datasets](target-dataset) can be saved as a variety of
+[movement poses datasets](target-poses-and-bboxes-dataset) can be saved in a variety of
 formats, including DeepLabCut-style files (.h5 or .csv) and
 [SLEAP-style analysis files](sleap:tutorials/analysis) (.h5).
-First import the {mod}`movement.io.save_poses` module:
+To export pose tracks from `movement`, first import the {mod}`movement.io.save_poses` module:
 
 ```python
 from movement.io import save_poses
 ```
@@ -110,7 +180,7 @@ Then, depending on the desired format, use one of the following functions:
 
 ::::{tab-item} SLEAP
 
-Save to SLEAP-style analysis files (.h5):
+To save as a SLEAP analysis file in .h5 format:
 ```python
 save_poses.to_sleap_analysis_file(ds, "/path/to/file.h5")
 ```
@@ -129,7 +199,7 @@ each attribute and data variable represents, see the
 
 ::::{tab-item} DeepLabCut
 
-Save to DeepLabCut-style files (.h5 or .csv):
+To save as a DeepLabCut file, in .h5 or .csv format:
 ```python
 save_poses.to_dlc_file(ds, "/path/to/file.h5")  # preferred format
 save_poses.to_dlc_file(ds, "/path/to/file.csv")
@@ -143,13 +213,13 @@ save the data as separate single-animal DeepLabCut-style files.
 
 ::::{tab-item} LightningPose
 
-Save to LightningPose files (.csv).
+To save as a LightningPose file in .csv format:
 ```python
 save_poses.to_lp_file(ds, "/path/to/file.csv")
 ```
 
 :::{note}
-Because LightningPose saves pose estimation outputs in the same format as single-animal
-DeepLabCut projects, the above command is equivalent to:
+Because LightningPose follows the single-animal
+DeepLabCut .csv format, the above command is equivalent to:
 ```python
 save_poses.to_dlc_file(ds, "/path/to/file.csv", split_individuals=True)
 ```
@@ -157,3 +227,33 @@ save_poses.to_dlc_file(ds, "/path/to/file.csv", split_individuals=True)
 
 ::::
 
 :::::
+
+
+(target-saving-bboxes-tracks)=
+## Saving bounding boxes' tracks
+
+We currently do not provide explicit methods to export a movement bounding boxes dataset in a specific format. However, you can easily save the bounding boxes' trajectories to a .csv file using the standard Python library `csv`.
+
+Here is an example of how you can save a bounding boxes dataset to a .csv file:
+
+```python
+import csv
+
+# define the path to the output csv file
+filepath = "tracking_output.csv"
+
+# open the csv file in write mode
+with open(filepath, mode="w", newline="") as file:
+    writer = csv.writer(file)
+
+    # write the header
+    writer.writerow(["frame_idx", "bbox_ID", "x", "y", "width", "height", "confidence"])
+
+    # write the data
+    for individual in ds.individuals.data:
+        for frame in ds.time.data:
+            x, y = ds.position.sel(time=frame, individuals=individual).data
+            width, height = ds.shape.sel(time=frame, individuals=individual).data
+            confidence = ds.confidence.sel(time=frame, individuals=individual).data
+            writer.writerow([frame, individual, x, y, width, height, confidence])
+```
+Alternatively, we can convert the `movement` bounding boxes' dataset to a pandas DataFrame with the {meth}`xarray.DataArray.to_dataframe` method, wrangle the dataframe as required, and then apply the {meth}`pandas.DataFrame.to_csv` method to save the data as a .csv file.
diff --git a/docs/source/getting_started/installation.md b/docs/source/getting_started/installation.md
index d4f51123..404acceb 100644
--- a/docs/source/getting_started/installation.md
+++ b/docs/source/getting_started/installation.md
@@ -1,68 +1,62 @@
 (target-installation)=
 # Installation
 
-## Create a conda environment
-
+## Install the package
 :::{admonition} Use a conda environment
 :class: note
-We recommend you install movement inside a [conda](conda:)
-or [mamba](mamba:) environment, to avoid dependency conflicts with other packages.
-In the following we assume you have `conda` installed,
-but the same commands will also work with `mamba`/`micromamba`.
+To avoid dependency conflicts with other packages, it is best practice to install Python packages within a virtual environment. +We recommend using [conda](conda:) or [mamba](mamba:) to create and manage this environment, as they simplify the installation process. +The following instructions assume that you have conda installed, but the same commands will also work with `mamba`/`micromamba`. ::: -First, create and activate an environment with some prerequisites. -You can call your environment whatever you like, we've used `movement-env`. +### Users +To install movement in a new environment, follow one of the options below. +We will use `movement-env` as the environment name, but you can choose any name you prefer. +::::{tab-set} +:::{tab-item} Conda +Create and activate an environment with movement installed: ```sh -conda create -n movement-env -c conda-forge python=3.10 pytables +conda create -n movement-env -c conda-forge movement conda activate movement-env ``` - -## Install the package - -Then install the `movement` package as described below. - -::::{tab-set} - -:::{tab-item} Users -To get the latest release from PyPI: - +::: +:::{tab-item} Pip +Create and activate an environment with some prerequisites: ```sh -pip install movement +conda create -n movement-env -c conda-forge python=3.11 pytables +conda activate movement-env ``` -If you have an older version of `movement` installed in the same environment, -you can update to the latest version with: - +Install the latest movement release from PyPI: ```sh -pip install --upgrade movement +pip install movement ``` ::: +:::: -:::{tab-item} Developers -To get the latest development version, clone the -[GitHub repository](movement-github:) -and then run from inside the repository: +### Developers +If you are a developer looking to contribute to movement, please refer to our [contributing guide](target-contributing) for detailed setup instructions and guidelines. +## Check the installation +To verify that the installation was successful, run (with `movement-env` activated): ```sh -pip install -e .[dev] # works on most shells -pip install -e '.[dev]' # works on zsh (the default shell on macOS) +movement info ``` +You should see a printout including the version numbers of movement +and some of its dependencies. -This will install the package in editable mode, including all `dev` dependencies. -Please see the [contributing guide](target-contributing) for more information. -::: - -:::: - -## Check the installation - -To verify that the installation was successful, you can run the following -command (with the `movement-env` activated): +## Update the package +To update movement to the latest version, we recommend installing it in a new environment, +as this prevents potential compatibility issues caused by changes in dependency versions. +To uninstall an existing environment named `movement-env`: ```sh -movement info +conda env remove -n movement-env ``` - -You should see a printout including the version numbers of `movement` -and some of its dependencies. +:::{tip} +If you are unsure about the environment name, you can get a list of the environments on your system with: +```sh +conda env list +``` +::: +Once the environment has been removed, you can create a new one following the [installation instructions](#install-the-package) above. 
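+
+For example, assuming you used the default `movement-env` name and installed via `conda`, the full update could look like this:
+
+```sh
+conda env remove -n movement-env
+conda create -n movement-env -c conda-forge movement
+conda activate movement-env
+```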
diff --git a/docs/source/getting_started/movement_dataset.md b/docs/source/getting_started/movement_dataset.md index ffadc150..c13128d2 100644 --- a/docs/source/getting_started/movement_dataset.md +++ b/docs/source/getting_started/movement_dataset.md @@ -1,94 +1,201 @@ -(target-dataset)= -# The movement dataset +(target-poses-and-bboxes-dataset)= +# The movement datasets -When you load predicted pose tracks into `movement`, they are represented -as an {class}`xarray.Dataset` object, which is a container for multiple data -arrays. Each array is in turn represented as an {class}`xarray.DataArray` -object, which you can think of as a multi-dimensional {class}`numpy.ndarray` +In `movement`, poses or bounding boxes' tracks are represented +as an {class}`xarray.Dataset` object. + +An {class}`xarray.Dataset` object is a container for multiple arrays. Each array is an {class}`xarray.DataArray` object holding different aspects of the collected data (position, time, confidence scores...). You can think of a {class}`xarray.DataArray` object as a multi-dimensional {class}`numpy.ndarray` with pandas-style indexing and labelling. -So, a `movement` dataset is simply an {class}`xarray.Dataset` with a specific -structure to represent pose tracks, associated confidence scores and relevant -metadata. Each dataset consists of **data variables**, **dimensions**, -**coordinates** and **attributes**. +So a `movement` dataset is simply an {class}`xarray.Dataset` with a specific +structure to represent pose tracks or bounding boxes' tracks. Because pose data and bounding boxes data are somewhat different, `movement` provides two types of datasets: `poses` datasets and `bboxes` datasets. + +To discuss the specifics of both types of `movement` datasets, it is useful to clarify some concepts such as **data variables**, **dimensions**, +**coordinates** and **attributes**. In the next section, we will describe these concepts and the `movement` datasets' structure in some detail. -In the next section, we will describe the -structure of a `movement` dataset in some detail. To learn more about `xarray` data structures in general, see the relevant [documentation](xarray:user-guide/data-structures.html). ## Dataset structure -![](../_static/dataset_structure.png) +```{figure} ../_static/dataset_structure.png +:alt: movement dataset structure + +An {class}`xarray.Dataset` is a collection of several data arrays that share some dimensions. The schematic shows the data arrays that make up the `poses` and `bboxes` datasets in `movement`. +``` + +The structure of a `movement` dataset `ds` can be easily inspected by simply +printing it. + +::::{tab-set} + +:::{tab-item} Poses dataset +To inspect a sample poses dataset, we can run: +```python +from movement import sample_data + +ds = sample_data.fetch_dataset( + "SLEAP_three-mice_Aeon_proofread.analysis.h5", +) +print(ds) +``` + +and we would obtain an output such as: +``` + Size: 27kB +Dimensions: (time: 601, individuals: 3, keypoints: 1, space: 2) +Coordinates: + * time (time) float64 5kB 0.0 0.02 0.04 0.06 ... 11.96 11.98 12.0 + * individuals (individuals) Size: 19kB +Dimensions: (time: 5, individuals: 86, space: 2) +Coordinates: + * time (time) int64 40B 0 1 2 3 4 + * individuals (individuals) `. +# For example, the path could be something like: + +# uncomment and edit the following line to point to your own local file +# file_path = "/path/to/my/data.h5" + +# %% +# For the sake of this example, we will use the path to one of +# the sample datasets provided with ``movement``. 
+ +file_path = sample_data.fetch_dataset_paths( + "SLEAP_single-mouse_EPM.analysis.h5" +)["poses"] +print(file_path) + +# %% +# Now let's load this file into a +# :ref:`movement poses dataset`, +# which we can then modify to our liking. + +ds = load_poses.from_sleap_file(file_path, fps=30) +print(ds, "\n") +print("Individuals:", ds.coords["individuals"].values) +print("Keypoints:", ds.coords["keypoints"].values) + + +# %% +# .. note:: +# If you're running this code in a Jupyter notebook, +# you can just type ``ds`` (instead of printing it) +# to explore the dataset interactively. + +# %% +# Rename keypoints +# ---------------- +# We start with a dictionary that maps old keypoint names to new ones. +# Next, we define a function that takes that dictionary and a dataset +# as inputs, and returns a modified dataset. Notice that under the hood +# this function calls :meth:`xarray.Dataset.assign_coords`. + +rename_dict = { + "snout": "nose", + "left_ear": "earL", + "right_ear": "earR", + "centre": "middle", + "tail_base": "tailbase", + "tail_end": "tailend", +} + + +def rename_keypoints(ds, rename_dict): + # get the current names of the keypoints + keypoint_names = ds.coords["keypoints"].values + + # rename the keypoints + if not rename_dict: + print("No keypoints to rename. Skipping renaming step.") + else: + new_keypoints = [rename_dict.get(kp, str(kp)) for kp in keypoint_names] + # Assign the modified values back to the Dataset + ds = ds.assign_coords(keypoints=new_keypoints) + return ds + + +# %% +# Let's apply the function to our dataset and see the results. +ds_renamed = rename_keypoints(ds, rename_dict) +print("Keypoints in modified dataset:", ds_renamed.coords["keypoints"].values) + + +# %% +# Delete keypoints +# ----------------- +# Let's create a list of keypoints to delete. +# In this case, we choose to get rid of the ``tailend`` keypoint, +# which is often hard to reliably track. +# We delete it using :meth:`xarray.Dataset.drop_sel`, +# wrapped in an appropriately named function. + +keypoints_to_delete = ["tailend"] + + +def delete_keypoints(ds, delete_keypoints): + if not delete_keypoints: + print("No keypoints to delete. Skipping deleting step.") + else: + # Delete the specified keypoints and their corresponding data + ds = ds.drop_sel(keypoints=delete_keypoints) + return ds + + +ds_deleted = delete_keypoints(ds_renamed, keypoints_to_delete) +print("Keypoints in modified dataset:", ds_deleted.coords["keypoints"].values) + + +# %% +# Reorder keypoints +# ------------------ +# We start with a list of keypoints in the desired order +# (in this case, we'll just swap the order of the left and right ears). +# We then use :meth:`xarray.Dataset.reindex`, wrapped in yet another function. + +ordered_keypoints = ["nose", "earR", "earL", "middle", "tailbase"] + + +def reorder_keypoints(ds, ordered_keypoints): + if not ordered_keypoints: + print("No keypoints to reorder. Skipping reordering step.") + else: + ds = ds.reindex(keypoints=ordered_keypoints) + return ds + + +ds_reordered = reorder_keypoints(ds_deleted, ordered_keypoints) +print( + "Keypoints in modified dataset:", ds_reordered.coords["keypoints"].values +) + +# %% +# Save the modified dataset +# --------------------------- +# Now that we have modified the dataset to our liking, +# let's save it to a .csv file in the DeepLabCut format. +# In this case, we save the file to a temporary +# directory, and we use the same file name +# as the original, but ending in ``_dlc.csv``. 
+# You will need to specify a different ``target_dir`` and edit +# the ``dest_path`` variable to your liking. + +target_dir = tempfile.mkdtemp() +dest_path = Path(target_dir) / f"{file_path.stem}_dlc.csv" + +save_poses.to_dlc_file(ds_reordered, dest_path, split_individuals=False) +print(f"Saved modified dataset to {dest_path}.") + +# %% +# .. note:: +# The ``split_individuals`` argument allows you to save +# a dataset with multiple individuals as separate files, +# with the individual ID appended to each file name. +# In this case, we set it to ``False`` because we only have +# one individual in the dataset, and we don't need its name +# appended to the file name. + + +# %% +# One function to rule them all +# ----------------------------- +# Since we know how to rename, delete, and reorder keypoints, +# let's put it all together in a single function +# and see how we could apply it to multiple files at once, +# as we might do in a real-world scenario. +# +# The following function will convert all files in a folder +# (that end with a specified suffix) from SLEAP to DeepLabCut format. +# Each file will be loaded, modified according to the +# ``rename_dict``, ``keypoints_to_delete``, and ``ordered_keypoints`` +# we've defined above, and saved to the target directory. + + +data_dir = "/path/to/your/data/" +target_dir = "/path/to/your/target/data/" + + +def convert_all(data_dir, target_dir, suffix=".slp"): + source_folder = Path(data_dir) + file_paths = list(source_folder.rglob(f"*{suffix}")) + + for file_path in file_paths: + file_path = Path(file_path) + + # this determines the file names for the modified files + dest_path = Path(target_dir) / f"{file_path.stem}_dlc.csv" + + if dest_path.exists(): + print(f"Skipping {file_path} as {dest_path} already exists.") + continue + + if file_path.exists(): + print(f"Processing: {file_path}") + # load the data from SLEAP file + ds = load_poses.from_sleap_file(file_path) + # modify the data + ds_renamed = rename_keypoints(ds, rename_dict) + ds_deleted = delete_keypoints(ds_renamed, keypoints_to_delete) + ds_reordered = reorder_keypoints(ds_deleted, ordered_keypoints) + # save modified data to a DeepLabCut file + save_poses.to_dlc_file( + ds_reordered, dest_path, split_individuals=False + ) + else: + raise ValueError( + f"File '{file_path}' does not exist. " + f"Please check the file path and try again." + ) diff --git a/examples/filter_and_interpolate.py b/examples/filter_and_interpolate.py index b3461dc2..baa17e99 100644 --- a/examples/filter_and_interpolate.py +++ b/examples/filter_and_interpolate.py @@ -1,5 +1,5 @@ -"""Filtering and interpolation -============================ +"""Drop outliers and interpolate +================================ Filter out points with low confidence scores and interpolate over missing values. @@ -10,6 +10,7 @@ # ------- from movement import sample_data from movement.filtering import filter_by_confidence, interpolate_over_time +from movement.kinematics import compute_velocity # %% # Load a sample dataset @@ -19,16 +20,21 @@ print(ds) # %% -# We can see that this dataset contains the 2D pose tracks and confidence -# scores for a single wasp, generated with DeepLabCut. There are 2 keypoints: -# "head" and "stinger". +# We see that the dataset contains the 2D pose tracks and confidence scores +# for a single wasp, generated with DeepLabCut. The wasp is tracked at two +# keypoints: "head" and "stinger" in a video that was recorded at 40 fps and +# lasts for approximately 27 seconds. 
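# %%
# As a quick check of the numbers quoted above, we can read the sampling
# rate from the dataset's ``fps`` attribute (set when loading this sample
# dataset) and compute the recording duration from the ``time`` coordinate.

print(f"Sampling rate: {ds.fps} fps")
print(f"Duration: {ds.time[-1].item():.2f} seconds")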
# %% # Visualise the pose tracks # ------------------------- +# Since the data contains only a single wasp, we use +# :meth:`xarray.DataArray.squeeze` to remove +# the dimension of length 1 from the data (the ``individuals`` dimension). -position = ds.position.sel(individuals="individual_0") -position.plot.line(x="time", row="keypoints", hue="space", aspect=2, size=2.5) +ds.position.squeeze().plot.line( + x="time", row="keypoints", hue="space", aspect=2, size=2.5 +) # %% # We can see that the pose tracks contain some implausible "jumps", such @@ -46,70 +52,77 @@ # estimation frameworks, and their ranges can vary. Therefore, # it's always a good idea to inspect the actual confidence values in the data. # -# Let's first look at a histogram of the confidence scores. -ds.confidence.plot.hist(bins=20) +# Let's first look at a histogram of the confidence scores. As before, we use +# :meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension +# from the data. + +ds.confidence.squeeze().plot.hist(bins=20) # %% # Based on the above histogram, we can confirm that the confidence scores # indeed range between 0 and 1, with most values closer to 1. Now let's see how # they evolve over time. -confidence = ds.confidence.sel(individuals="individual_0") -confidence.plot.line(x="time", row="keypoints", aspect=2, size=2.5) +ds.confidence.squeeze().plot.line( + x="time", row="keypoints", aspect=2, size=2.5 +) # %% # Encouragingly, some of the drops in confidence scores do seem to correspond # to the implausible jumps and spikes we had seen in the position. # We can use that to our advantage. - # %% # Filter out points with low confidence # ------------------------------------- -# We can filter out points with confidence scores below a certain threshold. -# Here, we use ``threshold=0.6``. Points in the ``position`` data variable -# with confidence scores below this threshold will be converted to NaN. -# The ``print_report`` argument, which is True by default, reports the number -# of NaN values in the dataset before and after the filtering operation. +# Using the :func:`movement.filtering.filter_by_confidence` function from the +# :mod:`movement.filtering` module, we can filter out points with confidence +# scores below a certain threshold. This function takes ``position`` and +# ``confidence`` as required arguments, and accepts an optional ``threshold`` +# parameter, which defaults to ``threshold=0.6`` unless specified otherwise. +# The function will also report the number of NaN values in the dataset before +# and after the filtering operation by default, but you can disable this +# by passing ``print_report=False``. +# +# We will use :meth:`xarray.Dataset.update` to update ``ds`` in-place +# with the filtered ``position``. -ds_filtered = filter_by_confidence(ds, threshold=0.6, print_report=True) +ds.update({"position": filter_by_confidence(ds.position, ds.confidence)}) # %% # We can see that the filtering operation has introduced NaN values in the # ``position`` data variable. Let's visualise the filtered data. -position_filtered = ds_filtered.position.sel(individuals="individual_0") -position_filtered.plot.line( +ds.position.squeeze().plot.line( x="time", row="keypoints", hue="space", aspect=2, size=2.5 ) # %% -# Here we can see that gaps have appeared in the pose tracks, some of which -# are over the implausible jumps and spikes we had seen earlier. Moreover, -# most gaps seem to be brief, lasting < 1 second. 
+# Here we can see that gaps (consecutive NaNs) have appeared in the +# pose tracks, some of which are over the implausible jumps and spikes we had +# seen earlier. Moreover, most gaps seem to be brief, +# lasting < 1 second (or 40 frames). # %% # Interpolate over missing values # ------------------------------- -# We can interpolate over the gaps we've introduced in the pose tracks. -# Here we use the default linear interpolation method and ``max_gap=1``, -# meaning that we will only interpolate over gaps of 1 second or shorter. -# Setting ``max_gap=None`` would interpolate over all gaps, regardless of -# their length, which should be used with caution as it can introduce +# Using the :func:`movement.filtering.interpolate_over_time` function from the +# :mod:`movement.filtering` module, we can interpolate over gaps +# we've introduced in the pose tracks. +# Here we use the default linear interpolation method (``method=linear``) +# and interpolate over gaps of 40 frames or less (``max_gap=40``). +# The default ``max_gap=None`` would interpolate over all gaps, regardless of +# their length, but this should be used with caution as it can introduce # spurious data. The ``print_report`` argument acts as described above. -ds_interpolated = interpolate_over_time( - ds_filtered, method="linear", max_gap=1, print_report=True -) +ds.update({"position": interpolate_over_time(ds.position, max_gap=40)}) # %% # We see that all NaN values have disappeared, meaning that all gaps were -# indeed shorter than 1 second. Let's visualise the interpolated pose tracks +# indeed shorter than 40 frames. +# Let's visualise the interpolated pose tracks. -position_interpolated = ds_interpolated.position.sel( - individuals="individual_0" -) -position_interpolated.plot.line( +ds.position.squeeze().plot.line( x="time", row="keypoints", hue="space", aspect=2, size=2.5 ) @@ -119,9 +132,35 @@ # So, far we've processed the pose tracks first by filtering out points with # low confidence scores, and then by interpolating over missing values. # The order of these operations and the parameters with which they were -# performed are saved in the ``log`` attribute of the dataset. +# performed are saved in the ``log`` attribute of the ``position`` data array. # This is useful for keeping track of the processing steps that have been -# applied to the data. +# applied to the data. Let's inspect the log entries. -for log_entry in ds_interpolated.log: +for log_entry in ds.position.log: print(log_entry) + +# %% +# Filtering multiple data variables +# --------------------------------- +# We can also apply the same filtering operation to +# multiple data variables in ``ds`` at the same time. +# +# For instance, to filter both ``position`` and ``velocity`` data variables +# in ``ds``, based on the confidence scores, we can specify a dictionary +# with the data variable names as keys and the corresponding filtered +# DataArrays as values. Then we can once again use +# :meth:`xarray.Dataset.update` to update ``ds`` in-place +# with the filtered data variables. 
+ +# Add velocity data variable to the dataset +ds["velocity"] = compute_velocity(ds.position) + +# Create a dictionary mapping data variable names to filtered DataArrays +# We disable report printing for brevity +update_dict = { + var: filter_by_confidence(ds[var], ds.confidence, print_report=False) + for var in ["position", "velocity"] +} + +# Use the dictionary to update the dataset in-place +ds.update(update_dict) diff --git a/examples/smooth.py b/examples/smooth.py new file mode 100644 index 00000000..f87ac411 --- /dev/null +++ b/examples/smooth.py @@ -0,0 +1,344 @@ +"""Smooth pose tracks +===================== + +Smooth pose tracks using the median and Savitzky-Golay filters. +""" + +# %% +# Imports +# ------- + +from matplotlib import pyplot as plt +from scipy.signal import welch + +from movement import sample_data +from movement.filtering import ( + interpolate_over_time, + median_filter, + savgol_filter, +) + +# %% +# Load a sample dataset +# --------------------- +# Let's load a sample dataset and print it to inspect its contents. +# Note that if you are running this notebook interactively, you can simply +# type the variable name (here ``ds_wasp``) in a cell to get an interactive +# display of the dataset's contents. + +ds_wasp = sample_data.fetch_dataset("DLC_single-wasp.predictions.h5") +print(ds_wasp) + +# %% +# We see that the dataset contains the 2D pose tracks and confidence scores +# for a single wasp, generated with DeepLabCut. The wasp is tracked at two +# keypoints: "head" and "stinger" in a video that was recorded at 40 fps and +# lasts for approximately 27 seconds. + +# %% +# Define a plotting function +# -------------------------- +# Let's define a plotting function to help us visualise the effects of +# smoothing both in the time and frequency domains. +# The function takes as inputs two datasets containing raw and smooth data +# respectively, and plots the position time series and power spectral density +# (PSD) for a given individual and keypoint. The function also allows you to +# specify the spatial coordinate (``x`` or ``y``) and a time range to focus on. 
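# Note that the function interpolates over any missing values (NaNs) before
# computing the PSD, as NaNs would otherwise propagate into the spectral
# estimate.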
+ + +def plot_raw_and_smooth_timeseries_and_psd( + ds_raw, + ds_smooth, + individual="individual_0", + keypoint="stinger", + space="x", + time_range=None, +): + # If no time range is specified, plot the entire time series + if time_range is None: + time_range = slice(0, ds_raw.time[-1]) + + selection = { + "time": time_range, + "individuals": individual, + "keypoints": keypoint, + "space": space, + } + + fig, ax = plt.subplots(2, 1, figsize=(10, 6)) + + for ds, color, label in zip( + [ds_raw, ds_smooth], ["k", "r"], ["raw", "smooth"], strict=False + ): + # plot position time series + pos = ds.position.sel(**selection) + ax[0].plot( + pos.time, + pos, + color=color, + lw=2, + alpha=0.7, + label=f"{label} {space}", + ) + + # interpolate data to remove NaNs in the PSD calculation + pos_interp = interpolate_over_time(pos, print_report=False) + + # compute and plot the PSD + freq, psd = welch(pos_interp, fs=ds.fps, nperseg=256) + ax[1].semilogy( + freq, + psd, + color=color, + lw=2, + alpha=0.7, + label=f"{label} {space}", + ) + + ax[0].set_ylabel(f"{space} position (px)") + ax[0].set_xlabel("Time (s)") + ax[0].set_title("Time Domain") + ax[0].legend() + + ax[1].set_ylabel("PSD (px$^2$/Hz)") + ax[1].set_xlabel("Frequency (Hz)") + ax[1].set_title("Frequency Domain") + ax[1].legend() + + plt.tight_layout() + fig.show() + + +# %% +# Smoothing with a median filter +# ------------------------------ +# Using the :func:`movement.filtering.median_filter` function on the +# ``position`` data variable, we can apply a rolling window median filter +# over a 0.1-second window (4 frames) to the wasp dataset. +# As the ``window`` parameter is defined in *number of observations*, +# we can simply multiply the desired time window by the frame rate +# of the video. We will also create a copy of the dataset to avoid +# modifying the original data. + +window = int(0.1 * ds_wasp.fps) +ds_wasp_smooth = ds_wasp.copy() +ds_wasp_smooth.update({"position": median_filter(ds_wasp.position, window)}) + +# %% +# We see from the printed report that the dataset has no missing values +# neither before nor after smoothing. Let's visualise the effects of the +# median filter in the time and frequency domains. + +plot_raw_and_smooth_timeseries_and_psd( + ds_wasp, ds_wasp_smooth, keypoint="stinger" +) + +# %% +# We see that the median filter has removed the "spikes" present around the +# 14 second mark in the raw data. However, it has not dealt the big shift +# occurring during the final second. In the frequency domain, we can see that +# the filter has reduced the power in the high-frequency components, without +# affecting the low frequency components. +# +# This illustrates what the median filter is good at: removing brief "spikes" +# (e.g. a keypoint abruptly jumping to a different location for a frame or two) +# and high-frequency "jitter" (often present due to pose estimation +# working on a per-frame basis). + +# %% +# Choosing parameters for the median filter +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# We can control the behaviour of the median filter +# via two parameters: ``window`` and ``min_periods``. +# To better understand the effect of these parameters, let's use a +# dataset that contains missing values. + +ds_mouse = sample_data.fetch_dataset("SLEAP_single-mouse_EPM.analysis.h5") +print(ds_mouse) + +# %% +# The dataset contains a single mouse with six keypoints tracked in +# 2D space. The video was recorded at 30 fps and lasts for ~616 seconds. 
We can +# see that there are some missing values, indicated as "nan" in the +# printed dataset. +# Let's apply the median filter over a 0.1-second window (3 frames) +# to the dataset. + +window = int(0.1 * ds_mouse.fps) +ds_mouse_smooth = ds_mouse.copy() +ds_mouse_smooth.update({"position": median_filter(ds_mouse.position, window)}) + +# %% +# The report informs us that the raw data contains NaN values, most of which +# occur at the ``snout`` and ``tail_end`` keypoints. After filtering, the +# number of NaNs has increased. This is because the default behaviour of the +# median filter is to propagate NaN values, i.e. if any value in the rolling +# window is NaN, the output will also be NaN. +# +# To modify this behaviour, you can set the value of the ``min_periods`` +# parameter to an integer value. This parameter determines the minimum number +# of non-NaN values required in the window for the output to be non-NaN. +# For example, setting ``min_periods=2`` means that two non-NaN values in the +# window are sufficient for the median to be calculated. Let's try this. + +ds_mouse_smooth.update( + {"position": median_filter(ds_mouse.position, window, min_periods=2)} +) + +# %% +# We see that this time the number of NaN values has decreased +# across all keypoints. +# Let's visualise the effects of the median filter in the time and frequency +# domains. Here we focus on the first 80 seconds for the ``snout`` keypoint. +# You can adjust the ``keypoint`` and ``time_range`` arguments to explore other +# parts of the data. + +plot_raw_and_smooth_timeseries_and_psd( + ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80) +) + +# %% +# The smoothing once again reduces the power of high-frequency components, but +# the resulting time series stays quite close to the raw data. +# +# What happens if we increase the ``window`` to 2 seconds (60 frames)? + +window = int(2 * ds_mouse.fps) +ds_mouse_smooth.update( + {"position": median_filter(ds_mouse.position, window, min_periods=2)} +) + +# %% +# The number of NaN values has decreased even further. +# That's because the chance of finding at least 2 valid values within +# a 2-second window (i.e. 60 frames) is quite high. +# Let's plot the results for the same keypoint and time range +# as before. + +plot_raw_and_smooth_timeseries_and_psd( + ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80) +) +# %% +# We see that the filtered time series is much smoother and it has even +# "bridged" over some small gaps. That said, it often deviates from the raw +# data, in ways that may not be desirable, depending on the application. +# Here, our choice of ``window`` may be too large. +# In general, you should choose a ``window`` that is small enough to +# preserve the original data structure, but large enough to remove +# "spikes" and high-frequency noise. Always inspect the results to ensure +# that the filter is not removing important features. + +# %% +# Smoothing with a Savitzky-Golay filter +# -------------------------------------- +# Here we apply the :func:`movement.filtering.savgol_filter` function +# (a wrapper around :func:`scipy.signal.savgol_filter`), to the ``position`` +# data variable. +# The Savitzky-Golay filter is a polynomial smoothing filter that can be +# applied to time series data on a rolling window basis. +# A polynomial with a degree specified by ``polyorder`` is applied to each +# data segment defined by the size ``window``. 
+# The value of the polynomial at the midpoint of each ``window`` is then +# used as the output value. +# +# Let's try it on the mouse dataset, this time using a 0.2-second +# window (i.e. 6 frames) and the default ``polyorder=2`` for smoothing. +# As before, we first compute the corresponding number of observations +# to be used as the ``window`` size. + +window = int(0.2 * ds_mouse.fps) +ds_mouse_smooth.update({"position": savgol_filter(ds_mouse.position, window)}) + +# %% +# We see that the number of NaN values has increased after filtering. This is +# for the same reason as with the median filter (in its default mode), i.e. +# if there is at least one NaN value in the window, the output will be NaN. +# Unlike the median filter, the Savitzky-Golay filter does not provide a +# ``min_periods`` parameter to control this behaviour. Let's visualise the +# effects in the time and frequency domains. + +plot_raw_and_smooth_timeseries_and_psd( + ds_mouse, ds_mouse_smooth, keypoint="snout", time_range=slice(0, 80) +) +# %% +# Once again, the power of high-frequency components has been reduced, but more +# missing values have been introduced. + +# %% +# Now let's apply the same Savitzky-Golay filter to the wasp dataset. + +window = int(0.2 * ds_wasp.fps) +ds_wasp_smooth.update({"position": savgol_filter(ds_wasp.position, window)}) + +# %% +plot_raw_and_smooth_timeseries_and_psd( + ds_wasp, ds_wasp_smooth, keypoint="stinger" +) +# %% +# This example shows two important limitations of the Savitzky-Golay filter. +# First, the filter can introduce artefacts around sharp boundaries. For +# example, focus on what happens around the sudden drop in position +# during the final second. Second, the PSD appears to have large periodic +# drops at certain frequencies. Both of these effects vary with the +# choice of ``window`` and ``polyorder``. You can read more about these +# and other limitations of the Savitzky-Golay filter in +# `this paper `_. + + +# %% +# Combining multiple smoothing filters +# ------------------------------------ +# We can also combine multiple smoothing filters by applying them +# sequentially. For example, we can first apply the median filter with a small +# ``window`` to remove "spikes" and then apply the Savitzky-Golay filter +# with a larger ``window`` to further smooth the data. +# Between the two filters, we can interpolate over small gaps to avoid the +# excessive proliferation of NaN values. Let's try this on the mouse dataset. + +# First, we will apply the median filter. +window = int(0.1 * ds_mouse.fps) +ds_mouse_smooth.update( + {"position": median_filter(ds_mouse.position, window, min_periods=2)} +) + +# Next, let's linearly interpolate over gaps smaller +# than 1 second (30 frames). +ds_mouse_smooth.update( + {"position": interpolate_over_time(ds_mouse_smooth.position, max_gap=30)} +) + +# Finally, let's apply the Savitzky-Golay filter +# over a 0.4-second window (12 frames). +window = int(0.4 * ds_mouse.fps) +ds_mouse_smooth.update( + {"position": savgol_filter(ds_mouse_smooth.position, window)} +) + +# %% +# A record of all applied operations is stored in the ``log`` attribute of the +# ``ds_mouse_smooth.position`` data array. Let's inspect it to summarise +# what we've done. + +for entry in ds_mouse_smooth.position.log: + print(entry) + +# %% +# Now let's visualise the difference between the raw data and the final +# smoothed result. 
+ +plot_raw_and_smooth_timeseries_and_psd( + ds_mouse, + ds_mouse_smooth, + keypoint="snout", + time_range=slice(0, 80), +) + +# %% +# Feel free to play around with the parameters of the applied filters and to +# also look at other keypoints and time ranges. + +# %% +# .. seealso:: +# :ref:`examples/filter_and_interpolate:Filtering multiple data variables` +# in the +# :ref:`sphx_glr_examples_filter_and_interpolate.py` example. diff --git a/movement/__init__.py b/movement/__init__.py index 7f6efda3..bf5d4a2d 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -1,7 +1,6 @@ from importlib.metadata import PackageNotFoundError, version -from movement.logging import configure_logging -from movement.move_accessor import MovementDataset +from movement.utils.logging import configure_logging try: __version__ = version("movement") diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py deleted file mode 100644 index a1241b29..00000000 --- a/movement/analysis/kinematics.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Functions for computing kinematic variables.""" - -import numpy as np -import xarray as xr - -from movement.logging import log_error - - -def compute_displacement(data: xr.DataArray) -> xr.DataArray: - """Compute displacement between consecutive positions. - - This is the difference between consecutive positions of each keypoint for - each individual across time. At each time point ``t``, it's defined as a - vector in cartesian ``(x,y)`` coordinates, pointing from the previous - ``(t-1)`` to the current ``(t)`` position. - - Parameters - ---------- - data : xarray.DataArray - The input data containing ``time`` as a dimension. - - Returns - ------- - xarray.DataArray - An xarray DataArray containing the computed displacement. - - """ - _validate_time_dimension(data) - result = data.diff(dim="time") - result = result.reindex(data.coords, fill_value=0) - return result - - -def compute_velocity(data: xr.DataArray) -> xr.DataArray: - """Compute the velocity in cartesian ``(x,y)`` coordinates. - - Velocity is the first derivative of position for each keypoint - and individual across time. It's computed using numerical differentiation - and assumes equidistant time spacing. - - Parameters - ---------- - data : xarray.DataArray - The input data containing ``time`` as a dimension. - - Returns - ------- - xarray.DataArray - An xarray DataArray containing the computed velocity. - - """ - return _compute_approximate_derivative(data, order=1) - - -def compute_acceleration(data: xr.DataArray) -> xr.DataArray: - """Compute acceleration in cartesian ``(x,y)`` coordinates. - - Acceleration represents the second derivative of position for each keypoint - and individual across time. It's computed using numerical differentiation - and assumes equidistant time spacing. - - Parameters - ---------- - data : xarray.DataArray - The input data containing ``time`` as a dimension. - - Returns - ------- - xarray.DataArray - An xarray DataArray containing the computed acceleration. - - """ - return _compute_approximate_derivative(data, order=2) - - -def _compute_approximate_derivative( - data: xr.DataArray, order: int -) -> xr.DataArray: - """Compute the derivative using numerical differentiation. - - This assumes equidistant time spacing. - - Parameters - ---------- - data : xarray.DataArray - The input data containing ``time`` as a dimension. - order : int - The order of the derivative. 1 for velocity, 2 for - acceleration. Value must be a positive integer. 
- - Returns - ------- - xarray.DataArray - An xarray DataArray containing the derived variable. - - """ - if not isinstance(order, int): - raise log_error( - TypeError, f"Order must be an integer, but got {type(order)}." - ) - if order <= 0: - raise log_error(ValueError, "Order must be a positive integer.") - _validate_time_dimension(data) - result = data - dt = data["time"].values[1] - data["time"].values[0] - for _ in range(order): - result = xr.apply_ufunc( - np.gradient, - result, - dt, - kwargs={"axis": 0}, - ) - result = result.reindex_like(data) - return result - - -def _validate_time_dimension(data: xr.DataArray) -> None: - """Validate the input data contains a ``time`` dimension. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - - Raises - ------ - ValueError - If the input data does not contain a ``time`` dimension. - - """ - if "time" not in data.dims: - raise log_error( - ValueError, "Input data must contain 'time' as a dimension." - ) diff --git a/movement/cli_entrypoint.py b/movement/cli_entrypoint.py index f760ba87..8be03137 100644 --- a/movement/cli_entrypoint.py +++ b/movement/cli_entrypoint.py @@ -9,7 +9,7 @@ import movement -ASCII_ART = """ +ASCII_ART = r""" _ __ ___ _____ _____ _ __ ___ ___ _ __ | |_ | '_ ` _ \ / _ \ \ / / _ \ '_ ` _ \ / _ \ '_ \| __| | | | | | | (_) \ V / __/ | | | | | __/ | | | |_ diff --git a/movement/filtering.py b/movement/filtering.py index 451830e8..8432c133 100644 --- a/movement/filtering.py +++ b/movement/filtering.py @@ -1,214 +1,156 @@ -"""Functions for filtering and interpolating pose tracks in xarray datasets.""" - -import logging -from datetime import datetime -from functools import wraps -from typing import Optional, Union +"""Filter and interpolate tracks in ``movement`` datasets.""" import xarray as xr from scipy import signal -from movement.logging import log_error - - -def log_to_attrs(func): - """Log the operation performed by the wrapped function. - - This decorator appends log entries to the xarray.Dataset's "log" attribute. - For the decorator to work, the wrapped function must accept an - xarray.Dataset as its first argument and return an xarray.Dataset. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - - log_entry = { - "operation": func.__name__, - "datetime": str(datetime.now()), - **{f"arg_{i}": arg for i, arg in enumerate(args[1:], start=1)}, - **kwargs, - } - - # Append the log entry to the result's attributes - if result is not None and hasattr(result, "attrs"): - if "log" not in result.attrs: - result.attrs["log"] = [] - result.attrs["log"].append(log_entry) - - return result - - return wrapper - - -def report_nan_values(ds: xr.Dataset, ds_label: str = "dataset"): - """Report the number and percentage of points that are NaN. - - Numbers are reported for each individual and keypoint in the dataset. - - Parameters - ---------- - ds : xarray.Dataset - Dataset containing position, confidence scores, and metadata. - ds_label : str - Label to identify the dataset in the report. Default is "dataset". 
- - """ - # Compile the report - nan_report = f"\nMissing points (marked as NaN) in {ds_label}:" - for ind in ds.individuals.values: - nan_report += f"\n\tIndividual: {ind}" - for kp in ds.keypoints.values: - # Get the position track for the current individual and keypoint - position = ds.position.sel(individuals=ind, keypoints=kp) - # A point is considered NaN if any of its space coordinates are NaN - n_nans = position.isnull().any(["space"]).sum(["time"]).item() - n_points = position.time.size - percent_nans = round((n_nans / n_points) * 100, 1) - nan_report += f"\n\t\t{kp}: {n_nans}/{n_points} ({percent_nans}%)" - - # Write nan report to logger - logger = logging.getLogger(__name__) - logger.info(nan_report) - # Also print the report to the console - print(nan_report) - return None +from movement.utils.logging import log_error, log_to_attrs +from movement.utils.reports import report_nan_values @log_to_attrs -def interpolate_over_time( - ds: xr.Dataset, - method: str = "linear", - max_gap: Union[int, None] = None, +def filter_by_confidence( + data: xr.DataArray, + confidence: xr.DataArray, + threshold: float = 0.6, print_report: bool = True, -) -> Union[xr.Dataset, None]: - """Fill in NaN values by interpolating over the time dimension. +) -> xr.DataArray: + """Drop data points below a certain confidence threshold. + + Data points with an associated confidence value below the threshold are + converted to NaN. Parameters ---------- - ds : xarray.Dataset - Dataset containing position, confidence scores, and metadata. - method : str - String indicating which method to use for interpolation. - Default is ``linear``. See documentation for - ``xarray.DataArray.interpolate_na`` for complete list of options. - max_gap : - The largest time gap of consecutive NaNs (in seconds) that will be - interpolated over. The default value is ``None`` (no limit). + data : xarray.DataArray + The input data to be filtered. + confidence : xarray.DataArray + The data array containing confidence scores to filter by. + threshold : float + The confidence threshold below which datapoints are filtered. + A default value of ``0.6`` is used. See notes for more information. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after interpolation. Default is ``True``. + before and after filtering. Default is ``True``. Returns ------- - ds_interpolated : xr.Dataset - The provided dataset (ds), where NaN values have been - interpolated over using the parameters provided. + xarray.DataArray + The data where points with a confidence value below the + user-defined threshold have been converted to NaNs. + + Notes + ----- + For the poses dataset case, note that the point-wise confidence values + reported by various pose estimation frameworks are not standardised, and + the range of values can vary. For example, DeepLabCut reports a likelihood + value between 0 and 1, whereas the point confidence reported by SLEAP can + range above 1. Therefore, the default threshold value will not be + appropriate for all datasets and does not have the same meaning across + pose estimation frameworks. We advise users to inspect the confidence + values in their dataset and adjust the threshold accordingly. 
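    Examples
    --------
    A minimal usage sketch, filtering the ``position`` data array of one of
    the sample datasets shipped with ``movement`` and assigning the result
    back to the dataset:

    >>> from movement import sample_data
    >>> from movement.filtering import filter_by_confidence
    >>> ds = sample_data.fetch_dataset("DLC_single-wasp.predictions.h5")
    >>> ds["position"] = filter_by_confidence(
    ...     ds.position, ds.confidence, threshold=0.6, print_report=False
    ... )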
""" - ds_interpolated = ds.copy() - position_interpolated = ds.position.interpolate_na( - dim="time", method=method, max_gap=max_gap, fill_value="extrapolate" - ) - ds_interpolated.update({"position": position_interpolated}) + data_filtered = data.where(confidence >= threshold) if print_report: - report_nan_values(ds, "input dataset") - report_nan_values(ds_interpolated, "interpolated dataset") - return ds_interpolated + print(report_nan_values(data, "input")) + print(report_nan_values(data_filtered, "output")) + return data_filtered @log_to_attrs -def filter_by_confidence( - ds: xr.Dataset, - threshold: float = 0.6, +def interpolate_over_time( + data: xr.DataArray, + method: str = "linear", + max_gap: int | None = None, print_report: bool = True, -) -> Union[xr.Dataset, None]: - """Drop all points below a certain confidence threshold. +) -> xr.DataArray: + """Fill in NaN values by interpolating over the ``time`` dimension. - Position points with an associated confidence value below the threshold are - converted to NaN. + This method uses :meth:`xarray.DataArray.interpolate_na` under the + hood and passes the ``method`` and ``max_gap`` parameters to it. + See the xarray documentation for more details on these parameters. Parameters ---------- - ds : xarray.Dataset - Dataset containing position, confidence scores, and metadata. - threshold : float - The confidence threshold below which datapoints are filtered. - A default value of ``0.6`` is used. See notes for more information. + data : xarray.DataArray + The input data to be interpolated. + method : str + String indicating which method to use for interpolation. + Default is ``linear``. + max_gap : int, optional + Maximum size of gap, a continuous sequence of missing observations + (represented as NaNs), to fill. + The default value is ``None`` (no limit). + Gap size is defined as the number of consecutive NaNs. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after filtering. Default is ``True``. + before and after interpolation. Default is ``True``. Returns ------- - ds_thresholded : xarray.Dataset - The provided dataset (ds), where points with a confidence - value below the user-defined threshold have been converted - to NaNs. + xarray.DataArray + The data where NaN values have been interpolated over + using the parameters provided. Notes ----- - The point-wise confidence values reported by various pose estimation - frameworks are not standardised, and the range of values can vary. - For example, DeepLabCut reports a likelihood value between 0 and 1, whereas - the point confidence reported by SLEAP can range above 1. - Therefore, the default threshold value will not be appropriate for all - datasets and does not have the same meaning across pose estimation - frameworks. We advise users to inspect the confidence values - in their dataset and adjust the threshold accordingly. + The ``max_gap`` parameter differs slightly from that in + :meth:`xarray.DataArray.interpolate_na`, in which the gap size + is defined as the difference between the ``time`` coordinate values + at the first data point after a gap and the last value before a gap. 
""" - ds_thresholded = ds.copy() - ds_thresholded.update( - {"position": ds.position.where(ds.confidence >= threshold)} + data_interpolated = data.interpolate_na( + dim="time", + method=method, + use_coordinate=False, + max_gap=max_gap + 1 if max_gap is not None else None, + fill_value="extrapolate", ) if print_report: - report_nan_values(ds, "input dataset") - report_nan_values(ds_thresholded, "filtered dataset") - - return ds_thresholded + print(report_nan_values(data, "input")) + print(report_nan_values(data_interpolated, "output")) + return data_interpolated @log_to_attrs def median_filter( - ds: xr.Dataset, - window_length: int, - min_periods: Optional[int] = None, + data: xr.DataArray, + window: int, + min_periods: int | None = None, print_report: bool = True, -) -> xr.Dataset: - """Smooth pose tracks by applying a median filter over time. +) -> xr.DataArray: + """Smooth data by applying a median filter over time. Parameters ---------- - ds : xarray.Dataset - Dataset containing position, confidence scores, and metadata. - window_length : int - The size of the filter window. Window length is interpreted - as being in the input dataset's time unit, which can be inspected - with ``ds.time_unit``. + data : xarray.DataArray + The input data to be smoothed. + window : int + The size of the smoothing window, representing the fixed number + of observations used for each window. min_periods : int - Minimum number of observations in window required to have a value - (otherwise result is NaN). The default, None, is equivalent to - setting ``min_periods`` equal to the size of the window. + Minimum number of observations in the window required to have + a value (otherwise result is NaN). The default, None, is + equivalent to setting ``min_periods`` equal to the size of the window. This argument is directly passed to the ``min_periods`` parameter of - ``xarray.DataArray.rolling``. + :meth:`xarray.DataArray.rolling`. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after filtering. Default is ``True``. + before and after smoothing. Default is ``True``. Returns ------- - ds_smoothed : xarray.Dataset - The provided dataset (ds), where pose tracks have been smoothed - using a median filter with the provided parameters. + xarray.DataArray + The data smoothed using a median filter with the provided parameters. Notes ----- - By default, whenever one or more NaNs are present in the filter window, + By default, whenever one or more NaNs are present in the smoothing window, a NaN is returned to the output array. As a result, any - stretch of NaNs present in the input dataset will be propagated - proportionally to the size of the window in frames (specifically, by - ``floor(window_length/2)``). To control this behaviour, the + stretch of NaNs present in the input data will be propagated + proportionally to the size of the window (specifically, by + ``floor(window/2)``). To control this behaviour, the ``min_periods`` option can be used to specify the minimum number of non-NaN values required in the window to compute a result. For example, setting ``min_periods=1`` will result in the filter returning NaNs @@ -216,85 +158,74 @@ def median_filter( is sufficient to compute the median. 
""" - ds_smoothed = ds.copy() - - # Express window length (and its half) in frames - if ds.time_unit == "seconds": - window_length = int(window_length * ds.fps) - - half_window = window_length // 2 - - ds_smoothed.update( - { - "position": ds.position.pad( # Pad the edges to avoid NaNs - time=half_window, mode="reflect" - ) - .rolling( # Take rolling windows across time - time=window_length, center=True, min_periods=min_periods - ) - .median( # Compute the median of each window - skipna=True - ) - .isel( # Remove the padded edges - time=slice(half_window, -half_window) - ) - } + half_window = window // 2 + data_smoothed = ( + data.pad( # Pad the edges to avoid NaNs + time=half_window, mode="reflect" + ) + .rolling( # Take rolling windows across time + time=window, center=True, min_periods=min_periods + ) + .median( # Compute the median of each window + skipna=True + ) + .isel( # Remove the padded edges + time=slice(half_window, -half_window) + ) ) - if print_report: - report_nan_values(ds, "input dataset") - report_nan_values(ds_smoothed, "filtered dataset") - - return ds_smoothed + print(report_nan_values(data, "input")) + print(report_nan_values(data_smoothed, "output")) + return data_smoothed @log_to_attrs def savgol_filter( - ds: xr.Dataset, - window_length: int, + data: xr.DataArray, + window: int, polyorder: int = 2, print_report: bool = True, **kwargs, -) -> xr.Dataset: - """Smooth pose tracks by applying a Savitzky-Golay filter over time. +) -> xr.DataArray: + """Smooth data by applying a Savitzky-Golay filter over time. Parameters ---------- - ds : xarray.Dataset - Dataset containing position, confidence scores, and metadata. - window_length : int - The size of the filter window. Window length is interpreted - as being in the input dataset's time unit, which can be inspected - with ``ds.time_unit``. + data : xarray.DataArray + The input data to be smoothed. + window : int + The size of the smoothing window, representing the fixed number + of observations used for each window. polyorder : int The order of the polynomial used to fit the samples. Must be - less than ``window_length``. By default, a ``polyorder`` of + less than ``window``. By default, a ``polyorder`` of 2 is used. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after filtering. Default is ``True``. + before and after smoothing. Default is ``True``. **kwargs : dict - Additional keyword arguments are passed to scipy.signal.savgol_filter. + Additional keyword arguments are passed to + :func:`scipy.signal.savgol_filter`. Note that the ``axis`` keyword argument may not be overridden. Returns ------- - ds_smoothed : xarray.Dataset - The provided dataset (ds), where pose tracks have been smoothed - using a Savitzky-Golay filter with the provided parameters. + xarray.DataArray + The data smoothed using a Savitzky-Golay filter with the + provided parameters. Notes ----- - Uses the ``scipy.signal.savgol_filter`` function to apply a Savitzky-Golay - filter to the input dataset's ``position`` variable. - See the scipy documentation for more information on that function. - Whenever one or more NaNs are present in a filter window of the - input dataset, a NaN is returned to the output array. As a result, any - stretch of NaNs present in the input dataset will be propagated - proportionally to the size of the window in frames (specifically, by - ``floor(window_length/2)``). 
Note that, unlike - ``movement.filtering.median_filter()``, there is no ``min_periods`` + Uses the :func:`scipy.signal.savgol_filter` function to apply a + Savitzky-Golay filter to the input data. + See the SciPy documentation for more information on that function. + Whenever one or more NaNs are present in a smoothing window of the + input data, a NaN is returned to the output array. As a result, any + stretch of NaNs present in the input data will be propagated + proportionally to the size of the window (specifically, by + ``floor(window/2)``). Note that, unlike + :func:`movement.filtering.median_filter`, there is no ``min_periods`` option to control this behaviour. """ @@ -302,25 +233,15 @@ def savgol_filter( raise log_error( ValueError, "The 'axis' argument may not be overridden." ) - - ds_smoothed = ds.copy() - - if ds.time_unit == "seconds": - window_length = int(window_length * ds.fps) - - position_smoothed = signal.savgol_filter( - ds.position, - window_length, + data_smoothed = data.copy() + data_smoothed.values = signal.savgol_filter( + data, + window, polyorder, axis=0, **kwargs, ) - position_smoothed_da = ds.position.copy(data=position_smoothed) - - ds_smoothed.update({"position": position_smoothed_da}) - if print_report: - report_nan_values(ds, "input dataset") - report_nan_values(ds_smoothed, "filtered dataset") - - return ds_smoothed + print(report_nan_values(data, "input")) + print(report_nan_values(data_smoothed, "output")) + return data_smoothed diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py new file mode 100644 index 00000000..3e1b0e0d --- /dev/null +++ b/movement/io/load_bboxes.py @@ -0,0 +1,655 @@ +"""Load bounding boxes' tracking data into ``movement``.""" + +import ast +import logging +import re +from collections.abc import Callable +from pathlib import Path +from typing import Literal + +import numpy as np +import pandas as pd +import xarray as xr + +from movement.utils.logging import log_error +from movement.validators.datasets import ValidBboxesDataset +from movement.validators.files import ValidFile, ValidVIATracksCSV + +logger = logging.getLogger(__name__) + + +def from_numpy( + position_array: np.ndarray, + shape_array: np.ndarray, + confidence_array: np.ndarray | None = None, + individual_names: list[str] | None = None, + frame_array: np.ndarray | None = None, + fps: float | None = None, + source_software: str | None = None, +) -> xr.Dataset: + """Create a ``movement`` bounding boxes dataset from NumPy arrays. + + Parameters + ---------- + position_array : np.ndarray + Array of shape (n_frames, n_individuals, n_space) + containing the tracks of the bounding boxes' centroids. + It will be converted to a :class:`xarray.DataArray` object + named "position". + shape_array : np.ndarray + Array of shape (n_frames, n_individuals, n_space) + containing the shape of the bounding boxes. The shape of a bounding + box is its width (extent along the x-axis of the image) and height + (extent along the y-axis of the image). It will be converted to a + :class:`xarray.DataArray` object named "shape". + confidence_array : np.ndarray, optional + Array of shape (n_frames, n_individuals) containing + the confidence scores of the bounding boxes. If None (default), the + confidence scores are set to an array of NaNs. It will be converted + to a :class:`xarray.DataArray` object named "confidence". + individual_names : list of str, optional + List of individual names for the tracked bounding boxes in the video. 
+ If None (default), bounding boxes are assigned names based on the size + of the ``position_array``. The names will be in the format of + ``id_``, where is an integer from 0 to + ``position_array.shape[1]-1`` (i.e., "id_0", "id_1"...). + frame_array : np.ndarray, optional + Array of shape (n_frames, 1) containing the frame numbers for which + bounding boxes are defined. If None (default), frame numbers will + be assigned based on the first dimension of the ``position_array``, + starting from 0. If a specific array of frame numbers is provided, + these need to be consecutive integers. + fps : float, optional + The video sampling rate. If None (default), the ``time`` coordinates + of the resulting ``movement`` dataset will be in frame numbers. If + ``fps`` is provided, the ``time`` coordinates will be in seconds. If + the ``time`` coordinates are in seconds, they will indicate the + elapsed time from the capture of the first frame (assumed to be frame + 0). + source_software : str, optional + Name of the software that generated the data. Defaults to None. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the position, shape, and confidence + scores of the tracked bounding boxes, and any associated metadata. + + Examples + -------- + Create random position data for two bounding boxes, ``id_0`` and ``id_1``, + with the same width (40 pixels) and height (30 pixels). These are tracked + in 2D space for 100 frames, which are numbered from the start frame 1200 + to the end frame 1299. The confidence score for all bounding boxes is set + to 0.5. + + >>> import numpy as np + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_numpy( + ... position_array=np.random.rand(100, 2, 2), + ... shape_array=np.ones((100, 2, 2)) * [40, 30], + ... confidence_array=np.ones((100, 2)) * 0.5, + ... individual_names=["id_0", "id_1"], + ... frame_array=np.arange(1200, 1300).reshape(-1, 1), + ... ) + + Create a dataset with the same data as above, but with the time + coordinates in seconds. We use a video sampling rate of 60 fps. The time + coordinates in the resulting dataset will indicate the elapsed time from + the capture of the 0th frame. So for the frames 1200, 1201, 1203,... 1299 + the corresponding time coordinates in seconds will be 20, 20.0167, + 20.033,... 21.65 s. + + >>> ds = load_bboxes.from_numpy( + ... position_array=np.random.rand(100, 2, 2), + ... shape_array=np.ones((100, 2, 2)) * [40, 30], + ... confidence_array=np.ones((100, 2)) * 0.5, + ... individual_names=["id_0", "id_1"], + ... frame_array=np.arange(1200, 1300).reshape(-1, 1), + ... fps=60, + ... ) + + Create a dataset with the same data as above, but express the time + coordinate in frames, and assume the first tracked frame is frame 0. + To do this, we simply omit the ``frame_array`` input argument. + + >>> ds = load_bboxes.from_numpy( + ... position_array=np.random.rand(100, 2, 2), + ... shape_array=np.ones((100, 2, 2)) * [40, 30], + ... confidence_array=np.ones((100, 2)) * 0.5, + ... individual_names=["id_0", "id_1"], + ... ) + + Create a dataset with the same data as above, but express the time + coordinate in seconds, and assume the first tracked frame is captured + at time = 0 seconds. To do this, we omit the ``frame_array`` input argument + and pass an ``fps`` value. + + >>> ds = load_bboxes.from_numpy( + ... position_array=np.random.rand(100, 2, 2), + ... shape_array=np.ones((100, 2, 2)) * [40, 30], + ... confidence_array=np.ones((100, 2)) * 0.5, + ... individual_names=["id_0", "id_1"], + ... 
fps=60, + ... ) + + """ + valid_bboxes_data = ValidBboxesDataset( + position_array=position_array, + shape_array=shape_array, + confidence_array=confidence_array, + individual_names=individual_names, + frame_array=frame_array, + fps=fps, + source_software=source_software, + ) + return _ds_from_valid_data(valid_bboxes_data) + + +def from_file( + file_path: Path | str, + source_software: Literal["VIA-tracks"], + fps: float | None = None, + use_frame_numbers_from_file: bool = False, +) -> xr.Dataset: + """Create a ``movement`` bounding boxes dataset from a supported file. + + At the moment, we only support VIA-tracks .csv files. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing the tracked bounding boxes. Currently + only VIA-tracks .csv files are supported. + source_software : "VIA-tracks". + The source software of the file. Currently only files from the + VIA 2.0.12 annotator [1]_ ("VIA-tracks") are supported. + See . + fps : float, optional + The video sampling rate. If None (default), the ``time`` coordinates + of the resulting ``movement`` dataset will be in frame numbers. If + ``fps`` is provided, the ``time`` coordinates will be in seconds. If + the ``time`` coordinates are in seconds, they will indicate the + elapsed time from the capture of the first frame (assumed to be frame + 0). + use_frame_numbers_from_file : bool, optional + If True, the frame numbers in the resulting dataset are + the same as the ones specified for each tracked bounding box in the + input file. This may be useful if the bounding boxes are tracked for a + subset of frames in a video, but you want to maintain the start of the + full video as the time origin. If False (default), the frame numbers + in the VIA tracks .csv file are instead mapped to a 0-based sequence of + consecutive integers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the position, shape, and confidence + scores of the tracked bounding boxes, and any associated metadata. + + See Also + -------- + movement.io.load_bboxes.from_via_tracks_file + + References + ---------- + .. [1] https://www.robots.ox.ac.uk/~vgg/software/via/ + + Examples + -------- + Create a dataset from the VIA tracks .csv file at "path/to/file.csv", with + the time coordinates in seconds, and assuming t = 0 seconds corresponds to + the first tracked frame in the file. + + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_file( + >>> "path/to/file.csv", + >>> source_software="VIA-tracks", + >>> fps=30, + >>> ) + + """ + if source_software == "VIA-tracks": + return from_via_tracks_file( + file_path, + fps, + use_frame_numbers_from_file=use_frame_numbers_from_file, + ) + else: + raise log_error( + ValueError, f"Unsupported source software: {source_software}" + ) + + +def from_via_tracks_file( + file_path: Path | str, + fps: float | None = None, + use_frame_numbers_from_file: bool = False, +) -> xr.Dataset: + """Create a ``movement`` dataset from a VIA tracks .csv file. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the VIA tracks .csv file with the tracked bounding boxes. + For more information on the VIA tracks .csv file format, see the VIA + tutorial for tracking [1]_. + fps : float, optional + The video sampling rate. If None (default), the ``time`` coordinates + of the resulting ``movement`` dataset will be in frame numbers. If + ``fps`` is provided, the ``time`` coordinates will be in seconds. 
If + the ``time`` coordinates are in seconds, they will indicate the + elapsed time from the capture of the first frame (assumed to be frame + 0). + use_frame_numbers_from_file : bool, optional + If True, the frame numbers in the resulting dataset are + the same as the ones in the VIA tracks .csv file. This may be useful if + the bounding boxes are tracked for a subset of frames in a video, + but you want to maintain the start of the full video as the time + origin. If False (default), the frame numbers in the VIA tracks .csv + file are instead mapped to a 0-based sequence of consecutive integers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the position, shape, and confidence + scores of the tracked bounding boxes, and any associated metadata. + + Notes + ----- + The bounding boxes' IDs specified in the "track" field of the VIA + tracks .csv file are mapped to the "individual_name" column of the + ``movement`` dataset. The individual names follow the format ``id_``, + with N being the bounding box ID. + + References + ---------- + .. [1] https://www.robots.ox.ac.uk/~vgg/software/via/docs/face_track_annotation.html + + Examples + -------- + Create a dataset from the VIA tracks .csv file at "path/to/file.csv", with + the time coordinates in frames, and setting the first tracked frame in the + file as frame 0. + + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_via_tracks_file( + ... "path/to/file.csv", + ... ) + + Create a dataset from the VIA tracks .csv file at "path/to/file.csv", with + the time coordinates in seconds, and assuming t = 0 seconds corresponds to + the first tracked frame in the file. + + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_via_tracks_file( + ... "path/to/file.csv", + ... fps=30, + ... ) + + Create a dataset from the VIA tracks .csv file at "path/to/file.csv", with + the time coordinates in frames, and using the same frame numbers as + in the VIA tracks .csv file. + + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_via_tracks_file( + ... "path/to/file.csv", + ... use_frame_numbers_from_file=True. + ... ) + + Create a dataset from the VIA tracks .csv file at "path/to/file.csv", with + the time coordinates in seconds, and assuming t = 0 seconds corresponds to + the 0th frame in the full video. + + >>> from movement.io import load_bboxes + >>> ds = load_bboxes.from_via_tracks_file( + ... "path/to/file.csv", + ... fps=30, + ... use_frame_numbers_from_file=True, + ... 
) + + + """ + # General file validation + file = ValidFile( + file_path, expected_permission="r", expected_suffix=[".csv"] + ) + + # Specific VIA-tracks .csv file validation + via_file = ValidVIATracksCSV(file.path) + logger.debug(f"Validated VIA tracks .csv file {via_file.path}.") + + # Create an xarray.Dataset from the data + bboxes_arrays = _numpy_arrays_from_via_tracks_file(via_file.path) + ds = from_numpy( + position_array=bboxes_arrays["position_array"], + shape_array=bboxes_arrays["shape_array"], + confidence_array=bboxes_arrays["confidence_array"], + individual_names=[ + f"id_{id}" for id in bboxes_arrays["ID_array"].squeeze() + ], + frame_array=( + bboxes_arrays["frame_array"] + if use_frame_numbers_from_file + else None + ), + fps=fps, + source_software="VIA-tracks", + ) # it validates the dataset via ValidBboxesDataset + + # Add metadata as attributes + ds.attrs["source_software"] = "VIA-tracks" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded tracks of the bounding boxes from {via_file.path}:") + logger.info(ds) + return ds + + +def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: + """Extract numpy arrays from the input VIA tracks .csv file. + + The extracted numpy arrays are returned in a dictionary with the following + keys: + + - position_array (n_frames, n_individuals, n_space): + contains the trajectories of the bounding boxes' centroids. + - shape_array (n_frames, n_individuals, n_space): + contains the shape of the bounding boxes (width and height). + - confidence_array (n_frames, n_individuals): + contains the confidence score of each bounding box. + If no confidence scores are provided, they are set to an array of NaNs. + - ID_array (n_individuals, 1): + contains the integer IDs of the tracked bounding boxes. + - frame_array (n_frames, 1): + contains the frame numbers. + + Parameters + ---------- + file_path : pathlib.Path + Path to the VIA tracks .csv file containing the bounding boxes' tracks. + + Returns + ------- + dict + The validated bounding boxes' arrays. + + """ + # Extract 2D dataframe from input data + # (sort data by ID and frame number, and + # fill empty frame-ID pairs with nans) + df = _df_from_via_tracks_file(file_path) + + # Compute indices of the rows where the IDs switch + bool_id_diff_from_prev = df["ID"].ne(df["ID"].shift()) # pandas series + indices_id_switch = ( + bool_id_diff_from_prev.loc[lambda x: x].index[1:].to_numpy() + ) + + # Stack position, shape and confidence arrays along ID axis + map_key_to_columns = { + "position_array": ["x", "y"], + "shape_array": ["w", "h"], + "confidence_array": ["confidence"], + } + array_dict = {} + for key in map_key_to_columns: + list_arrays = np.split( + df[map_key_to_columns[key]].to_numpy(), + indices_id_switch, # indices along axis=0 + ) + + array_dict[key] = np.stack(list_arrays, axis=1).squeeze() + + # Transform position_array to represent centroid of bbox, + # rather than top-left corner + # (top left corner: corner of the bbox with minimum x and y coordinates) + array_dict["position_array"] += array_dict["shape_array"] / 2 + + # Add remaining arrays to dict + array_dict["ID_array"] = df["ID"].unique().reshape(-1, 1) + array_dict["frame_array"] = df["frame_number"].unique().reshape(-1, 1) + + return array_dict + + +def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame: + """Load VIA tracks .csv file as a dataframe. + + Read the VIA tracks .csv file as a pandas dataframe with columns: + - ID: the integer ID of the tracked bounding box. 
+ - frame_number: the frame number of the tracked bounding box. + - x: the x-coordinate of the tracked bounding box's top-left corner. + - y: the y-coordinate of the tracked bounding box's top-left corner. + - w: the width of the tracked bounding box. + - h: the height of the tracked bounding box. + - confidence: the confidence score of the tracked bounding box. + + The dataframe is sorted by ID and frame number, and for each ID, + empty frames are filled in with NaNs. The coordinates of the bboxes + are assumed to be in the image coordinate system (i.e., the top-left + corner of a bbox is its corner with minimum x and y coordinates). + """ + # Read VIA tracks .csv file as a pandas dataframe + df_file = pd.read_csv(file_path, sep=",", header=0) + + # Format to a 2D dataframe + df = pd.DataFrame( + { + "ID": _via_attribute_column_to_numpy( + df_file, "region_attributes", ["track"], int + ), + "frame_number": _extract_frame_number_from_via_tracks_df(df_file), + "x": _via_attribute_column_to_numpy( + df_file, "region_shape_attributes", ["x"], float + ), + "y": _via_attribute_column_to_numpy( + df_file, "region_shape_attributes", ["y"], float + ), + "w": _via_attribute_column_to_numpy( + df_file, "region_shape_attributes", ["width"], float + ), + "h": _via_attribute_column_to_numpy( + df_file, "region_shape_attributes", ["height"], float + ), + "confidence": _extract_confidence_from_via_tracks_df(df_file), + } + ) + + # Sort dataframe by ID and frame number + df = df.sort_values(by=["ID", "frame_number"]).reset_index(drop=True) + + # Fill in empty frames with nans + multi_index = pd.MultiIndex.from_product( + [df["ID"].unique(), df["frame_number"].unique()], + names=["ID", "frame_number"], + ) # desired index: all combinations of ID and frame number + + # Set index to (ID, frame number), fill in values with nans and + # reset to original index + df = ( + df.set_index(["ID", "frame_number"]).reindex(multi_index).reset_index() + ) + return df + + +def _extract_confidence_from_via_tracks_df(df) -> np.ndarray: + """Extract confidence scores from the VIA tracks input dataframe. + + Parameters + ---------- + df : pd.DataFrame + The VIA tracks input dataframe is the one obtained from + ``df = pd.read_csv(file_path, sep=",", header=0)``. + + Returns + ------- + np.ndarray + A numpy array of size (n_bboxes, ) containing the bounding boxes + confidence scores. + + """ + region_attributes_dicts = [ + ast.literal_eval(d) for d in df.region_attributes + ] + + # Check if confidence is defined as a region attribute, else set to NaN + if all(["confidence" in d for d in region_attributes_dicts]): + bbox_confidence = _via_attribute_column_to_numpy( + df, "region_attributes", ["confidence"], float + ) + else: + bbox_confidence = np.full((df.shape[0], 1), np.nan).squeeze() + + return bbox_confidence + + +def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray: + """Extract frame numbers from the VIA tracks input dataframe. + + Parameters + ---------- + df : pd.DataFrame + The VIA tracks input dataframe is the one obtained from + ``df = pd.read_csv(file_path, sep=",", header=0)``. + + Returns + ------- + np.ndarray + A numpy array of size (n_frames, ) containing the frame numbers. + In the VIA tracks .csv file, the frame number is expected to be + defined as a 'file_attribute' , or encoded in the filename as an + integer number led by at least one zero, between "_" and ".", followed + by the file extension. 
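The filename fallback described above can be made concrete with a short, self-contained sketch. The regular expression below mirrors the pattern used by the frame-number extraction helper further down; the filenames themselves are made up for illustration.

```python
import re

import numpy as np

# Fallback pattern: an integer led by at least one zero, sitting between
# an underscore and the file extension.
pattern = r"_(0\d*)\.\w+$"

filenames = ["video_0001.png", "video_0002.png", "clip_no_number.png"]

frame_numbers = [
    int(m.group(1)) if (m := re.search(pattern, f)) else np.nan
    for f in filenames
]
print(frame_numbers)  # [1, 2, nan]
```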
+ + """ + # Extract frame number from file_attributes if exists + file_attributes_dicts = [ast.literal_eval(d) for d in df.file_attributes] + if all(["frame" in d for d in file_attributes_dicts]): + frame_array = _via_attribute_column_to_numpy( + df, + via_column_name="file_attributes", + list_keys=["frame"], + cast_fn=int, + ) + # Else extract from filename + else: + pattern = r"_(0\d*)\.\w+$" + list_frame_numbers = [ + int(re.search(pattern, f).group(1)) # type: ignore + if re.search(pattern, f) + else np.nan + for f in df["filename"] + ] + + frame_array = np.array(list_frame_numbers) + + return frame_array + + +def _via_attribute_column_to_numpy( + df: pd.DataFrame, + via_column_name: str, + list_keys: list[str], + cast_fn: Callable = float, +) -> np.ndarray: + """Convert values from VIA attribute-type column to a numpy array. + + In the VIA tracks .csv file, the attribute-type columns are the columns + whose name includes the word ``attributes`` (i.e. ``file_attributes``, + ``region_shape_attributes`` or ``region_attributes``). These columns hold + dictionary data. + + Parameters + ---------- + df : pd.DataFrame + The pandas DataFrame containing the data from the VIA tracks .csv file. + This is the dataframe obtained from running + ``df = pd.read_csv(file_path, sep=",", header=0)``. + via_column_name : str + The name of a column in the VIA tracks .csv file whose values are + literal dictionaries (i.e. ``file_attributes``, + ``region_shape_attributes`` or ``region_attributes``). + list_keys : list[str] + The list of keys whose values we want to extract from the literal + dictionaries in the ``via_column_name`` column. + cast_fn : type, optional + The type function to cast the values to. By default ``float``. + + Returns + ------- + np.ndarray + A numpy array holding the extracted values. If ``len(list_keys) > 1`` + the array is two-dimensional with shape ``(N, len(list_keys))``, where + ``N`` is the number of rows in the input dataframe ``df``. If + ``len(list_keys) == 1``, the resulting array will be one-dimensional, + with shape (N, ). Note that the computed array is squeezed before + returning. + + """ + list_bbox_attr = [] + for _, row in df.iterrows(): + row_dict_data = ast.literal_eval(row[via_column_name]) + list_bbox_attr.append( + tuple(cast_fn(row_dict_data[reg]) for reg in list_keys) + ) + + bbox_attr_array = np.array(list_bbox_attr) + + return bbox_attr_array.squeeze() + + +def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset: + """Convert a validated bounding boxes dataset to an xarray Dataset. + + Parameters + ---------- + data : movement.validators.datasets.ValidBboxesDataset + The validated bounding boxes dataset object. + + Returns + ------- + bounding boxes dataset containing the boxes tracks, + boxes shapes, confidence scores and associated metadata. + + """ + # Create the time coordinate + time_coords = data.frame_array.squeeze() # type: ignore + time_unit = "frames" + # if fps is provided: + # time_coords is expressed in seconds, with the time origin + # set as frame 0 == time 0 seconds + if data.fps: + # Compute elapsed time from frame 0. 
+ # Ignoring type error because `data.frame_array` is not None after + # ValidBboxesDataset.__attrs_post_init__() # type: ignore + time_coords = np.array( + [frame / data.fps for frame in data.frame_array.squeeze()] # type: ignore + ) + time_unit = "seconds" + + # Convert data to an xarray.Dataset + # with dimensions ('time', 'individuals', 'space') + DIM_NAMES = ValidBboxesDataset.DIM_NAMES + n_space = data.position_array.shape[-1] + return xr.Dataset( + data_vars={ + "position": xr.DataArray(data.position_array, dims=DIM_NAMES), + "shape": xr.DataArray(data.shape_array, dims=DIM_NAMES), + "confidence": xr.DataArray( + data.confidence_array, dims=DIM_NAMES[:-1] + ), + }, + coords={ + DIM_NAMES[0]: time_coords, + DIM_NAMES[1]: data.individual_names, + DIM_NAMES[2]: ["x", "y", "z"][:n_space], + }, + attrs={ + "fps": data.fps, + "time_unit": time_unit, + "source_software": data.source_software, + "source_file": None, + "ds_type": "bboxes", + }, + ) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index fd1d280a..de259aa7 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -1,8 +1,8 @@ -"""Functions for loading pose tracking data from various frameworks.""" +"""Load pose tracking data from various frameworks into ``movement``.""" import logging from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal import h5py import numpy as np @@ -11,35 +11,98 @@ from sleap_io.io.slp import read_labels from sleap_io.model.labels import Labels -from movement import MovementDataset -from movement.io.validators import ( - ValidDeepLabCutCSV, - ValidFile, - ValidHDF5, - ValidPosesDataset, -) -from movement.logging import log_error, log_warning +from movement.utils.logging import log_error, log_warning +from movement.validators.datasets import ValidPosesDataset +from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5 logger = logging.getLogger(__name__) +def from_numpy( + position_array: np.ndarray, + confidence_array: np.ndarray | None = None, + individual_names: list[str] | None = None, + keypoint_names: list[str] | None = None, + fps: float | None = None, + source_software: str | None = None, +) -> xr.Dataset: + """Create a ``movement`` poses dataset from NumPy arrays. + + Parameters + ---------- + position_array : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the poses. It will be converted to a + :class:`xarray.DataArray` object named "position". + confidence_array : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. It will be converted to a + :class:`xarray.DataArray` object named "confidence". + If None (default), the scores will be set to an array of NaNs. + individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + Frames per second of the video. Defaults to None, in which case + the time coordinates will be in frame numbers. + source_software : str, optional + Name of the pose estimation software from which the data originate. + Defaults to None. 
+
+    Returns
+    -------
+    xarray.Dataset
+        ``movement`` dataset containing the pose tracks, confidence scores,
+        and associated metadata.
+
+    Examples
+    --------
+    Create random position data for two individuals, ``Alice`` and ``Bob``,
+    with three keypoints each: ``snout``, ``centre``, and ``tail_base``.
+    These are tracked in 2D space over 100 frames, at 30 fps.
+    The confidence scores are set to 1 for all points.
+
+    >>> import numpy as np
+    >>> from movement.io import load_poses
+    >>> ds = load_poses.from_numpy(
+    ...     position_array=np.random.rand(100, 2, 3, 2),
+    ...     confidence_array=np.ones((100, 2, 3)),
+    ...     individual_names=["Alice", "Bob"],
+    ...     keypoint_names=["snout", "centre", "tail_base"],
+    ...     fps=30,
+    ... )
+
+    """
+    valid_data = ValidPosesDataset(
+        position_array=position_array,
+        confidence_array=confidence_array,
+        individual_names=individual_names,
+        keypoint_names=keypoint_names,
+        fps=fps,
+        source_software=source_software,
+    )
+    return _ds_from_valid_data(valid_data)
+
+
 def from_file(
-    file_path: Union[Path, str],
+    file_path: Path | str,
     source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
-    fps: Optional[float] = None,
+    fps: float | None = None,
 ) -> xr.Dataset:
-    """Load pose tracking data from any supported file format.
-
-    Data can be loaded from a DeepLabCut (DLC), LightningPose (LP) or
-    SLEAP output file into an xarray Dataset.
+    """Create a ``movement`` poses dataset from any supported file.

     Parameters
     ----------
     file_path : pathlib.Path or str
         Path to the file containing predicted poses. The file format must
         be among those supported by the ``from_dlc_file()``,
-        ``from_slp_file()`` or ``from_lp_file()`` functions,
-        since one of these functions will be called internally, based on
+        ``from_slp_file()`` or ``from_lp_file()`` functions. One of these
+        functions will be called internally, based on
         the value of ``source_software``.
     source_software : "DeepLabCut", "SLEAP" or "LightningPose"
         The source software of the file.
@@ -50,7 +113,8 @@ def from_file(
     Returns
     -------
     xarray.Dataset
-        Dataset containing the pose tracks, confidence scores, and metadata.
+        ``movement`` dataset containing the pose tracks, confidence scores,
+        and associated metadata.

     See Also
     --------
@@ -58,6 +122,13 @@ def from_file(
     movement.io.load_poses.from_dlc_file
     movement.io.load_poses.from_sleap_file
     movement.io.load_poses.from_lp_file

+    Examples
+    --------
+    >>> from movement.io import load_poses
+    >>> ds = load_poses.from_file(
+    ...     "path/to/file.h5", source_software="DeepLabCut", fps=30
+    ... )
+
     """
     if source_software == "DeepLabCut":
         return from_dlc_file(file_path, fps)
@@ -71,8 +142,12 @@ def from_file(
     )


-def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
-    """Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame.
+def from_dlc_style_df(
+    df: pd.DataFrame,
+    fps: float | None = None,
+    source_software: Literal["DeepLabCut", "LightningPose"] = "DeepLabCut",
+) -> xr.Dataset:
+    """Create a ``movement`` poses dataset from a DeepLabCut-style DataFrame.

     Parameters
     ----------
@@ -82,11 +157,16 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
     fps : float, optional
         The number of frames per second in the video. If None (default),
         the `time` coordinates will be in frame numbers.
+    source_software : str, optional
+        Name of the pose estimation software from which the data originate.
+        Defaults to "DeepLabCut", but it can also be "LightningPose"
+        (because they use the same DataFrame format).
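For readers unfamiliar with the DeepLabCut layout, the following sketch builds a tiny single-animal DataFrame and passes it to ``from_dlc_style_df``. The scorer and keypoint names are hypothetical, and the column levels assume the standard single-animal DeepLabCut convention (``scorer``, ``bodyparts``, ``coords`` with ``x``, ``y``, ``likelihood`` per keypoint).

```python
import numpy as np
import pandas as pd

from movement.io import load_poses

# Hypothetical single-animal DeepLabCut-style columns: one (x, y, likelihood)
# triplet per keypoint, one row per frame.
columns = pd.MultiIndex.from_product(
    [["my_model"], ["snout", "tail_base"], ["x", "y", "likelihood"]],
    names=["scorer", "bodyparts", "coords"],
)
df = pd.DataFrame(np.random.rand(10, 6), columns=columns)

ds = load_poses.from_dlc_style_df(df, fps=30, source_software="DeepLabCut")
print(ds.position.shape)  # (10, 1, 2, 2): time, individuals, keypoints, space
```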
Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. Notes ----- @@ -99,7 +179,7 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: See Also -------- - movement.io.load_poses.from_dlc_file : Load pose tracks directly from file. + movement.io.load_poses.from_dlc_file """ # read names of individuals and keypoints from the DataFrame @@ -120,20 +200,20 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: (-1, len(individual_names), len(keypoint_names), 3) ) - valid_data = ValidPosesDataset( + return from_numpy( position_array=tracks_with_scores[:, :, :, :-1], confidence_array=tracks_with_scores[:, :, :, -1], individual_names=individual_names, keypoint_names=keypoint_names, fps=fps, + source_software=source_software, ) - return _from_valid_data(valid_data) def from_sleap_file( - file_path: Union[Path, str], fps: Optional[float] = None + file_path: Path | str, fps: float | None = None ) -> xr.Dataset: - """Load pose tracking data from a SLEAP file into an xarray Dataset. + """Create a ``movement`` poses dataset from a SLEAP file. Parameters ---------- @@ -148,7 +228,8 @@ def from_sleap_file( Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. Notes ----- @@ -193,16 +274,11 @@ def from_sleap_file( # Load and validate data if file.path.suffix == ".h5": - valid_data = _load_from_sleap_analysis_file(file.path, fps=fps) + ds = _ds_from_sleap_analysis_file(file.path, fps=fps) else: # file.path.suffix == ".slp" - valid_data = _load_from_sleap_labels_file(file.path, fps=fps) - logger.debug(f"Validated pose tracks from {file.path}.") - - # Initialize an xarray dataset from the dictionary - ds = _from_valid_data(valid_data) + ds = _ds_from_sleap_labels_file(file.path, fps=fps) # Add metadata as attrs - ds.attrs["source_software"] = "SLEAP" ds.attrs["source_file"] = file.path.as_posix() logger.info(f"Loaded pose tracks from {file.path}:") @@ -211,14 +287,14 @@ def from_sleap_file( def from_lp_file( - file_path: Union[Path, str], fps: Optional[float] = None + file_path: Path | str, fps: float | None = None ) -> xr.Dataset: - """Load pose tracking data from a LightningPose (LP) output file. + """Create a ``movement`` poses dataset from a LightningPose file. Parameters ---------- file_path : pathlib.Path or str - Path to the file containing the LP predicted poses, in .csv format. + Path to the file containing the predicted poses, in .csv format. fps : float, optional The number of frames per second in the video. If None (default), the `time` coordinates will be in frame numbers. @@ -226,7 +302,8 @@ def from_lp_file( Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. Examples -------- @@ -234,20 +311,20 @@ def from_lp_file( >>> ds = load_poses.from_lp_file("path/to/file.csv", fps=30) """ - return _from_lp_or_dlc_file( + return _ds_from_lp_or_dlc_file( file_path=file_path, source_software="LightningPose", fps=fps ) def from_dlc_file( - file_path: Union[Path, str], fps: Optional[float] = None + file_path: Path | str, fps: float | None = None ) -> xr.Dataset: - """Load pose tracking data from a DeepLabCut (DLC) output file. 
+ """Create a ``movement`` poses dataset from a DeepLabCut file. Parameters ---------- file_path : pathlib.Path or str - Path to the file containing the DLC predicted poses, either in .h5 + Path to the file containing the predicted poses, either in .h5 or .csv format. fps : float, optional The number of frames per second in the video. If None (default), @@ -256,11 +333,12 @@ def from_dlc_file( Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. See Also -------- - movement.io.load_poses.from_dlc_df : Load pose tracks from a DataFrame. + movement.io.load_poses.from_dlc_style_df Examples -------- @@ -268,22 +346,57 @@ def from_dlc_file( >>> ds = load_poses.from_dlc_file("path/to/file.h5", fps=30) """ - return _from_lp_or_dlc_file( + return _ds_from_lp_or_dlc_file( file_path=file_path, source_software="DeepLabCut", fps=fps ) -def _from_lp_or_dlc_file( - file_path: Union[Path, str], +def from_multiview_files( + file_path_dict: dict[str, Path | str], + source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + fps: float | None = None, +) -> xr.Dataset: + """Load and merge pose tracking data from multiple views (cameras). + + Parameters + ---------- + file_path_dict : dict[str, Union[Path, str]] + A dict whose keys are the view names and values are the paths to load. + source_software : {'LightningPose', 'SLEAP', 'DeepLabCut'} + The source software of the file. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata, + with an additional views dimension. + + """ + views_list = list(file_path_dict.keys()) + new_coord_views = xr.DataArray(views_list, dims="view") + + dataset_list = [ + from_file(f, source_software=source_software, fps=fps) + for f in file_path_dict.values() + ] + + return xr.concat(dataset_list, dim=new_coord_views) + + +def _ds_from_lp_or_dlc_file( + file_path: Path | str, source_software: Literal["LightningPose", "DeepLabCut"], - fps: Optional[float] = None, + fps: float | None = None, ) -> xr.Dataset: - """Load data from DeepLabCut (DLC) or LightningPose (LP) output files. + """Create a ``movement`` poses dataset from a LightningPose or DLC file. Parameters ---------- file_path : pathlib.Path or str - Path to the file containing the DLC predicted poses, either in .h5 + Path to the file containing the predicted poses, either in .h5 or .csv format. source_software : {'LightningPose', 'DeepLabCut'} The source software of the file. @@ -294,7 +407,8 @@ def _from_lp_or_dlc_file( Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. 
""" expected_suffix = [".csv"] @@ -305,35 +419,28 @@ def _from_lp_or_dlc_file( file_path, expected_permission="r", expected_suffix=expected_suffix ) - # Load the DLC poses into a DataFrame + # Load the DeepLabCut poses into a DataFrame if file.path.suffix == ".csv": - df = _load_df_from_dlc_csv(file.path) + df = _df_from_dlc_csv(file.path) else: # file.path.suffix == ".h5" - df = _load_df_from_dlc_h5(file.path) + df = _df_from_dlc_h5(file.path) logger.debug(f"Loaded poses from {file.path} into a DataFrame.") # Convert the DataFrame to an xarray dataset - ds = from_dlc_df(df=df, fps=fps) + ds = from_dlc_style_df(df=df, fps=fps, source_software=source_software) # Add metadata as attrs - ds.attrs["source_software"] = source_software ds.attrs["source_file"] = file.path.as_posix() - # If source_software="LightningPose", we need to re-validate (because the - # validation call in from_dlc_df was run with source_software="DeepLabCut") - # This rerun enforces a single individual for LightningPose datasets. - if source_software == "LightningPose": - ds.move.validate() - logger.info(f"Loaded pose tracks from {file.path}:") logger.info(ds) return ds -def _load_from_sleap_analysis_file( - file_path: Path, fps: Optional[float] -) -> ValidPosesDataset: - """Load and validate data from a SLEAP analysis file. +def _ds_from_sleap_analysis_file( + file_path: Path, fps: float | None +) -> xr.Dataset: + """Create a ``movement`` poses dataset from a SLEAP analysis (.h5) file. Parameters ---------- @@ -345,8 +452,9 @@ def _load_from_sleap_analysis_file( Returns ------- - movement.io.tracks_validators.ValidPosesDataset - The validated pose tracks and confidence scores. + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. """ file = ValidHDF5(file_path, expected_datasets=["tracks"]) @@ -367,7 +475,7 @@ def _load_from_sleap_analysis_file( # and transpose to shape: (n_frames, n_tracks, n_keypoints) if "point_scores" in f: scores = f["point_scores"][:].transpose((2, 0, 1)) - return ValidPosesDataset( + return from_numpy( position_array=tracks.astype(np.float32), confidence_array=scores.astype(np.float32), individual_names=individual_names, @@ -377,10 +485,10 @@ def _load_from_sleap_analysis_file( ) -def _load_from_sleap_labels_file( - file_path: Path, fps: Optional[float] -) -> ValidPosesDataset: - """Load and validate data from a SLEAP labels file. +def _ds_from_sleap_labels_file( + file_path: Path, fps: float | None +) -> xr.Dataset: + """Create a ``movement`` poses dataset from a SLEAP labels (.slp) file. Parameters ---------- @@ -392,8 +500,9 @@ def _load_from_sleap_labels_file( Returns ------- - movement.io.tracks_validators.ValidPosesDataset - The validated pose tracks and confidence scores. + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. """ file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"]) @@ -406,7 +515,7 @@ def _load_from_sleap_labels_file( "Assuming single-individual dataset and assigning " "default individual name." ) - return ValidPosesDataset( + return from_numpy( position_array=tracks_with_scores[:, :, :, :-1], confidence_array=tracks_with_scores[:, :, :, -1], individual_names=individual_names, @@ -417,9 +526,9 @@ def _load_from_sleap_labels_file( def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray: - """Convert a SLEAP `Labels` object to a NumPy array. + """Convert a SLEAP ``Labels`` object to a NumPy array. 
- The output array contains pose tracks with point-wise confidence scores. + The output array contains pose tracks and point-wise confidence scores. Parameters ---------- @@ -484,18 +593,18 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray: return tracks -def _load_df_from_dlc_csv(file_path: Path) -> pd.DataFrame: - """Parse a DeepLabCut-style .csv file into a pandas DataFrame. +def _df_from_dlc_csv(file_path: Path) -> pd.DataFrame: + """Create a DeepLabCut-style DataFrame from a .csv file. - If poses are loaded from a DeepLabCut .csv file, the DataFrame + If poses are loaded from a DeepLabCut-style .csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. This - function parses the .csv file to a pandas DataFrame with multi-index - columns, i.e. the same format as in the .h5 file. + function parses the .csv file to DataFrame with multi-index columns, + i.e. the same format as in the .h5 file. Parameters ---------- file_path : pathlib.Path - Path to the DeepLabCut-style .csv file. + Path to the DeepLabCut-style .csv file containing pose tracks. Returns ------- @@ -517,10 +626,12 @@ def _load_df_from_dlc_csv(file_path: Path) -> pd.DataFrame: # Form multi-index column names from the header lines level_names = [line[0] for line in header_lines] - column_tuples = list(zip(*[line[1:] for line in header_lines])) + column_tuples = list( + zip(*[line[1:] for line in header_lines], strict=False) + ) columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) - # Import the DLC poses as a DataFrame + # Import the DeepLabCut poses as a DataFrame df = pd.read_csv( file.path, skiprows=len(header_lines), @@ -531,8 +642,8 @@ def _load_df_from_dlc_csv(file_path: Path) -> pd.DataFrame: return df -def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: - """Load data from a DeepLabCut .h5 file into a pandas DataFrame. +def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame: + """Create a DeepLabCut-style DataFrame from a .h5 file. Parameters ---------- @@ -542,7 +653,7 @@ def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: Returns ------- pandas.DataFrame - DeepLabCut-style Dataframe. + DeepLabCut-style DataFrame with multi-index columns. """ file = ValidHDF5(file_path, expected_datasets=["df_with_missing"]) @@ -552,8 +663,8 @@ def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: return df -def _from_valid_data(data: ValidPosesDataset) -> xr.Dataset: - """Convert already validated pose tracking data to an xarray Dataset. +def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: + """Create a ``movement`` poses dataset from validated pose tracking data. Parameters ---------- @@ -563,7 +674,8 @@ def _from_valid_data(data: ValidPosesDataset) -> xr.Dataset: Returns ------- xarray.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. 
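The header-parsing step in ``_df_from_dlc_csv`` above can be illustrated in isolation: the first entry of each header row becomes a column-level name, and the remaining entries are zipped into multi-index tuples. The header values below are made up.

```python
import pandas as pd

# Toy header rows of a single-animal DeepLabCut .csv, already split on ",".
header_lines = [
    ["scorer"] + ["my_model"] * 6,
    ["bodyparts", "snout", "snout", "snout",
     "tail_base", "tail_base", "tail_base"],
    ["coords", "x", "y", "likelihood", "x", "y", "likelihood"],
]

# First entry of each row is the level name; the rest are column values.
level_names = [line[0] for line in header_lines]
column_tuples = list(zip(*[line[1:] for line in header_lines], strict=False))
columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names)

print(columns.names)  # ['scorer', 'bodyparts', 'coords']
print(columns[0])  # ('my_model', 'snout', 'x')
```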
""" n_frames = data.position_array.shape[0] @@ -576,7 +688,7 @@ def _from_valid_data(data: ValidPosesDataset) -> xr.Dataset: time_coords = time_coords / data.fps time_unit = "seconds" - DIM_NAMES = MovementDataset.dim_names + DIM_NAMES = ValidPosesDataset.DIM_NAMES # Convert data to an xarray.Dataset return xr.Dataset( data_vars={ @@ -594,7 +706,8 @@ def _from_valid_data(data: ValidPosesDataset) -> xr.Dataset: attrs={ "fps": data.fps, "time_unit": time_unit, - "source_software": None, + "source_software": data.source_software, "source_file": None, + "ds_type": "poses", }, ) diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index 63ce689b..c47d28f1 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -1,29 +1,33 @@ -"""Functions for saving pose tracking data to various file formats.""" +"""Save pose tracking data from ``movement`` to various file formats.""" import logging from pathlib import Path -from typing import Literal, Union +from typing import Literal import h5py import numpy as np import pandas as pd import xarray as xr -from movement.io.validators import ValidFile -from movement.logging import log_error +from movement.utils.logging import log_error +from movement.validators.datasets import ValidPosesDataset +from movement.validators.files import ValidFile logger = logging.getLogger(__name__) -def _xarray_to_dlc_df(ds: xr.Dataset, columns: pd.MultiIndex) -> pd.DataFrame: - """Convert an xarray dataset to DLC-style multi-index pandas DataFrame. +def _ds_to_dlc_style_df( + ds: xr.Dataset, columns: pd.MultiIndex +) -> pd.DataFrame: + """Convert a ``movement`` dataset to a DeepLabCut-style DataFrame. Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. columns : pandas.MultiIndex - DLC-style multi-index columns + DeepLabCut-style multi-index columns Returns ------- @@ -57,7 +61,7 @@ def _auto_split_individuals(ds: xr.Dataset) -> bool: def _save_dlc_df(filepath: Path, df: pd.DataFrame) -> None: - """Given a filepath, will save the dataframe as either a .h5 or .csv. + """Save the dataframe as either a .h5 or .csv depending on the file path. Parameters ---------- @@ -74,20 +78,20 @@ def _save_dlc_df(filepath: Path, df: pd.DataFrame) -> None: df.to_hdf(filepath, key="df_with_missing") -def to_dlc_df( +def to_dlc_style_df( ds: xr.Dataset, split_individuals: bool = False -) -> Union[pd.DataFrame, dict[str, pd.DataFrame]]: - """Convert an xarray dataset to DeepLabCut-style pandas DataFrame(s). +) -> pd.DataFrame | dict[str, pd.DataFrame]: + """Convert a ``movement`` dataset to DeepLabCut-style DataFrame(s). Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. split_individuals : bool, optional - If True, return a dictionary of pandas DataFrames per individual, - with individual names as keys and DataFrames as values. - If False, return a single pandas DataFrame for all individuals. - Default is False. + If True, return a dictionary of DataFrames per individual, with + individual names as keys. If False (default), return a single + DataFrame for all individuals (see Notes). Returns ------- @@ -107,8 +111,7 @@ def to_dlc_df( See Also -------- - to_dlc_file : Save the xarray dataset containing pose tracks directly - to a DeepLabCut-style .h5 or .csv file. 
+ to_dlc_file : Save dataset directly to a DeepLabCut-style .h5 or .csv file. """ _validate_dataset(ds) @@ -128,7 +131,7 @@ def to_dlc_df( [scorer, bodyparts, coords], names=index_levels ) - df = _xarray_to_dlc_df(individual_data, columns) + df = _ds_to_dlc_style_df(individual_data, columns) df_dict[individual] = df logger.info( @@ -142,46 +145,46 @@ def to_dlc_df( [scorer, individuals, bodyparts, coords], names=index_levels ) - df_all = _xarray_to_dlc_df(ds, columns) + df_all = _ds_to_dlc_style_df(ds, columns) - logger.info("Converted poses dataset to DLC-style DataFrame.") + logger.info("Converted poses dataset to DeepLabCut-style DataFrame.") return df_all def to_dlc_file( ds: xr.Dataset, - file_path: Union[str, Path], - split_individuals: Union[bool, Literal["auto"]] = "auto", + file_path: str | Path, + split_individuals: bool | Literal["auto"] = "auto", ) -> None: - """Save the xarray dataset to a DeepLabCut-style .h5 or .csv file. + """Save a ``movement`` dataset to DeepLabCut file(s). Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. file_path : pathlib.Path or str - Path to the file to save the DLC poses to. The file extension + Path to the file to save the poses to. The file extension must be either .h5 (recommended) or .csv. split_individuals : bool or "auto", optional - Whether to save individuals to separate files or to the same file.\n - If True, each individual will be saved to a separate file, - formatted as in a single-animal DeepLabCut project - i.e. without - the "individuals" column level. The individual's name will be appended - to the file path, just before the file extension, i.e. - "/path/to/filename_individual1.h5".\n - If False, all individuals will be saved to the same file, - formatted as in a multi-animal DeepLabCut project - i.e. the columns - will include the "individuals" level. The file path will not be - modified.\n - If "auto" the argument's value is determined based on the number of - individuals in the dataset: True if there is only one, and - False if there are more than one. This is the default. + Whether to save individuals to separate files or to the same file + (see Notes). Defaults to "auto". + + Notes + ----- + If ``split_individuals`` is True, each individual will be saved to a + separate file, formatted as in a single-animal DeepLabCut project + (without the "individuals" column level). The individual's name will be + appended to the file path, just before the file extension, e.g. + "/path/to/filename_individual1.h5". If False, all individuals will be + saved to the same file, formatted as in a multi-animal DeepLabCut project + (with the "individuals" column level). The file path will not be modified. + If "auto", the argument's value is determined based on the number of + individuals in the dataset: True if there is only one, False otherwise. See Also -------- - to_dlc_df : Convert an xarray dataset containing pose tracks into a single - DeepLabCut-style pandas DataFrame or a dictionary of DataFrames - per individual. + to_dlc_style_df : Convert dataset to DeepLabCut-style DataFrame(s). 
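A usage sketch of the two ``split_individuals`` modes described in the Notes, with placeholder file paths:

```python
from movement.io import load_poses, save_poses

ds = load_poses.from_file(
    "path/to/file.h5", source_software="DeepLabCut", fps=30
)

# Single multi-animal file, keeping the "individuals" column level.
save_poses.to_dlc_file(ds, "path/to/poses.h5", split_individuals=False)

# One single-animal file per individual, with the individual's name appended
# to the file name, e.g. "path/to/poses_individual_0.csv".
save_poses.to_dlc_file(ds, "path/to/poses.csv", split_individuals=True)
```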
Examples -------- @@ -205,7 +208,7 @@ def to_dlc_file( if split_individuals: # split the dataset into a dictionary of dataframes per individual - df_dict = to_dlc_df(ds, split_individuals=True) + df_dict = to_dlc_style_df(ds, split_individuals=True) for key, df in df_dict.items(): # the key is the individual's name @@ -215,7 +218,7 @@ def to_dlc_file( logger.info(f"Saved poses for individual {key} to {file.path}.") else: # convert the dataset to a single dataframe for all individuals - df_all = to_dlc_df(ds, split_individuals=False) + df_all = to_dlc_style_df(ds, split_individuals=False) if isinstance(df_all, pd.DataFrame): _save_dlc_df(file.path, df_all) logger.info(f"Saved poses dataset to {file.path}.") @@ -223,30 +226,31 @@ def to_dlc_file( def to_lp_file( ds: xr.Dataset, - file_path: Union[str, Path], + file_path: str | Path, ) -> None: - """Save the xarray dataset to a LightningPose-style .csv file (see Notes). + """Save a ``movement`` dataset to a LightningPose file. Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. file_path : pathlib.Path or str - Path to the .csv file to save the poses to. + Path to the file to save the poses to. File extension must be .csv. Notes ----- LightningPose saves pose estimation outputs as .csv files, using the same format as single-animal DeepLabCut projects. Therefore, under the hood, - this function calls ``to_dlc_file`` with ``split_individuals=True``. This - setting means that each individual is saved to a separate file, with - the individual's name appended to the file path, just before the file - extension, i.e. "/path/to/filename_individual1.csv". + this function calls :func:`movement.io.save_poses.to_dlc_file` + with ``split_individuals=True``. This setting means that each individual + is saved to a separate file, with the individual's name appended to the + file path, just before the file extension, + i.e. "/path/to/filename_individual1.csv". See Also -------- - to_dlc_file : Save the xarray dataset containing pose tracks to a - DeepLabCut-style .h5 or .csv file. + to_dlc_file : Save dataset to a DeepLabCut-style .h5 or .csv file. """ file = _validate_file_path(file_path=file_path, expected_suffix=[".csv"]) @@ -254,17 +258,16 @@ def to_lp_file( to_dlc_file(ds, file.path, split_individuals=True) -def to_sleap_analysis_file( - ds: xr.Dataset, file_path: Union[str, Path] -) -> None: - """Save the xarray dataset to a SLEAP-style .h5 analysis file. +def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None: + """Save a ``movement`` dataset to a SLEAP analysis file. Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. file_path : pathlib.Path or str - Path to the file to save the poses to. The file extension must be .h5. + Path to the file to save the poses to. File extension must be .h5. Notes ----- @@ -355,12 +358,13 @@ def to_sleap_analysis_file( def _remove_unoccupied_tracks(ds: xr.Dataset): - """Remove tracks that are completely unoccupied in the xarray dataset. + """Remove tracks that are completely unoccupied from the dataset. Parameters ---------- ds : xarray.Dataset - Dataset containing pose tracks, confidence scores, and metadata. + ``movement`` dataset containing pose tracks, confidence scores, + and associated metadata. 
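For completeness, the LightningPose and SLEAP writers follow the same pattern; the file paths below are placeholders.

```python
from movement.io import load_poses, save_poses

ds = load_poses.from_sleap_file("path/to/predictions.slp", fps=30)

# LightningPose-style .csv (single-animal format, one file per individual).
save_poses.to_lp_file(ds, "path/to/poses.csv")

# SLEAP-style analysis .h5 file.
save_poses.to_sleap_analysis_file(ds, "path/to/analysis.h5")
```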
Returns ------- @@ -373,7 +377,7 @@ def _remove_unoccupied_tracks(ds: xr.Dataset): def _validate_file_path( - file_path: Union[str, Path], expected_suffix: list[str] + file_path: str | Path, expected_suffix: list[str] ) -> ValidFile: """Validate the input file path. @@ -412,7 +416,7 @@ def _validate_file_path( def _validate_dataset(ds: xr.Dataset) -> None: - """Validate the input dataset is an xarray Dataset with valid poses. + """Validate the input as a proper ``movement`` dataset. Parameters ---------- @@ -421,12 +425,25 @@ def _validate_dataset(ds: xr.Dataset) -> None: Raises ------ + TypeError + If the input is not an xarray Dataset. ValueError - If `ds` is not an xarray Dataset with valid poses. + If the dataset is missing required data variables or dimensions. """ if not isinstance(ds, xr.Dataset): raise log_error( - ValueError, f"Expected an xarray Dataset, but got {type(ds)}." + TypeError, f"Expected an xarray Dataset, but got {type(ds)}." ) - ds.move.validate() # validate the dataset + + missing_vars = set(ValidPosesDataset.VAR_NAMES) - set(ds.data_vars) + if missing_vars: + raise ValueError( + f"Missing required data variables: {sorted(missing_vars)}" + ) # sort for a reproducible error message + + missing_dims = set(ValidPosesDataset.DIM_NAMES) - set(ds.dims) + if missing_dims: + raise ValueError( + f"Missing required dimensions: {sorted(missing_dims)}" + ) # sort for a reproducible error message diff --git a/movement/io/validators.py b/movement/io/validators.py deleted file mode 100644 index 38c11cdd..00000000 --- a/movement/io/validators.py +++ /dev/null @@ -1,373 +0,0 @@ -"""`attrs` classes for validating file paths and data structures.""" - -import os -from collections.abc import Iterable -from pathlib import Path -from typing import Any, Literal, Optional, Union - -import h5py -import numpy as np -from attrs import converters, define, field, validators - -from movement.logging import log_error, log_warning - - -@define -class ValidFile: - """Class for validating file paths. - - Parameters - ---------- - path : str or pathlib.Path - Path to the file. - expected_permission : {'r', 'w', 'rw'} - Expected access permission(s) for the file. If 'r', the file is - expected to be readable. If 'w', the file is expected to be writable. - If 'rw', the file is expected to be both readable and writable. - Default: 'r'. - expected_suffix : list of str - Expected suffix(es) for the file. If an empty list (default), this - check is skipped. - - Raises - ------ - IsADirectoryError - If the path points to a directory. - PermissionError - If the file does not have the expected access permission(s). - FileNotFoundError - If the file does not exist when `expected_permission` is 'r' or 'rw'. - FileExistsError - If the file exists when `expected_permission` is 'w'. - ValueError - If the file does not have one of the expected suffix(es). - - """ - - path: Path = field(converter=Path, validator=validators.instance_of(Path)) - expected_permission: Literal["r", "w", "rw"] = field( - default="r", validator=validators.in_(["r", "w", "rw"]), kw_only=True - ) - expected_suffix: list[str] = field(factory=list, kw_only=True) - - @path.validator - def path_is_not_dir(self, attribute, value): - """Ensure that the path does not point to a directory.""" - if value.is_dir(): - raise log_error( - IsADirectoryError, - f"Expected a file path but got a directory: {value}.", - ) - - @path.validator - def file_exists_when_expected(self, attribute, value): - """Ensure that the file exists (or not) as needed. 
- - This depends on the expected usage (read and/or write). - """ - if "r" in self.expected_permission: - if not value.exists(): - raise log_error( - FileNotFoundError, f"File {value} does not exist." - ) - else: # expected_permission is 'w' - if value.exists(): - raise log_error( - FileExistsError, f"File {value} already exists." - ) - - @path.validator - def file_has_access_permissions(self, attribute, value): - """Ensure that the file has the expected access permission(s). - - Raises a PermissionError if not. - """ - file_is_readable = os.access(value, os.R_OK) - parent_is_writeable = os.access(value.parent, os.W_OK) - if ("r" in self.expected_permission) and (not file_is_readable): - raise log_error( - PermissionError, - f"Unable to read file: {value}. " - "Make sure that you have read permissions.", - ) - if ("w" in self.expected_permission) and (not parent_is_writeable): - raise log_error( - PermissionError, - f"Unable to write to file: {value}. " - "Make sure that you have write permissions.", - ) - - @path.validator - def file_has_expected_suffix(self, attribute, value): - """Ensure that the file has one of the expected suffix(es).""" - if self.expected_suffix and value.suffix not in self.expected_suffix: - raise log_error( - ValueError, - f"Expected file with suffix(es) {self.expected_suffix} " - f"but got suffix {value.suffix} instead.", - ) - - -@define -class ValidHDF5: - """Class for validating HDF5 files. - - Parameters - ---------- - path : pathlib.Path - Path to the HDF5 file. - expected_datasets : list of str or None - List of names of the expected datasets in the HDF5 file. If an empty - list (default), this check is skipped. - - Raises - ------ - ValueError - If the file is not in HDF5 format or if it does not contain the - expected datasets. - - """ - - path: Path = field(validator=validators.instance_of(Path)) - expected_datasets: list[str] = field(factory=list, kw_only=True) - - @path.validator - def file_is_h5(self, attribute, value): - """Ensure that the file is indeed in HDF5 format.""" - try: - with h5py.File(value, "r") as f: - f.close() - except Exception as e: - raise log_error( - ValueError, - f"File {value} does not seem to be in valid" "HDF5 format.", - ) from e - - @path.validator - def file_contains_expected_datasets(self, attribute, value): - """Ensure that the HDF5 file contains the expected datasets.""" - if self.expected_datasets: - with h5py.File(value, "r") as f: - diff = set(self.expected_datasets).difference(set(f.keys())) - if len(diff) > 0: - raise log_error( - ValueError, - f"Could not find the expected dataset(s) {diff} " - f"in file: {value}. ", - ) - - -@define -class ValidDeepLabCutCSV: - """Class for validating DLC-style .csv files. - - Parameters - ---------- - path : pathlib.Path - Path to the .csv file. - - Raises - ------ - ValueError - If the .csv file does not contain the expected DeepLabCut index column - levels among its top rows. - - """ - - path: Path = field(validator=validators.instance_of(Path)) - - @path.validator - def csv_file_contains_expected_levels(self, attribute, value): - """Ensure that the .csv file contains the expected index column levels. - - These are to be found among the top 4 rows of the file. 
- """ - expected_levels = ["scorer", "bodyparts", "coords"] - - with open(value) as f: - top4_row_starts = [f.readline().split(",")[0] for _ in range(4)] - - if top4_row_starts[3].isdigit(): - # if 4th row starts with a digit, assume single-animal DLC file - expected_levels.append(top4_row_starts[3]) - else: - # otherwise, assume multi-animal DLC file - expected_levels.insert(1, "individuals") - - if top4_row_starts != expected_levels: - raise log_error( - ValueError, - ".csv header rows do not match the known format for " - "DeepLabCut pose estimation output files.", - ) - - -def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]: - """Try to coerce the value into a list of strings.""" - if isinstance(value, str): - log_warning( - f"Invalid value ({value}). Expected a list of strings. " - "Converting to a list of length 1." - ) - return [value] - elif isinstance(value, Iterable): - return [str(item) for item in value] - else: - raise log_error( - ValueError, f"Invalid value ({value}). Expected a list of strings." - ) - - -def _ensure_type_ndarray(value: Any) -> None: - """Raise ValueError the value is a not numpy array.""" - if not isinstance(value, np.ndarray): - raise log_error( - ValueError, f"Expected a numpy array, but got {type(value)}." - ) - - -def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: - """Set fps to None if a non-positive float is passed.""" - if fps is not None and fps <= 0: - log_warning( - f"Invalid fps value ({fps}). Expected a positive number. " - "Setting fps to None." - ) - return None - return fps - - -def _validate_list_length( - attribute: str, value: Optional[list], expected_length: int -): - """Raise a ValueError if the list does not have the expected length.""" - if (value is not None) and (len(value) != expected_length): - raise log_error( - ValueError, - f"Expected `{attribute}` to have length {expected_length}, " - f"but got {len(value)}.", - ) - - -@define(kw_only=True) -class ValidPosesDataset: - """Class for validating pose tracking data imported from a file. - - Attributes - ---------- - position_array : np.ndarray - Array of shape (n_frames, n_individuals, n_keypoints, n_space) - containing the poses. It will be converted to a - `xarray.DataArray` object named "position". - confidence_array : np.ndarray, optional - Array of shape (n_frames, n_individuals, n_keypoints) containing - the point-wise confidence scores. It will be converted to a - `xarray.DataArray` object named "confidence". - If None (default), the scores will be set to an array of NaNs. - individual_names : list of str, optional - List of unique names for the individuals in the video. If None - (default), the individuals will be named "individual_0", - "individual_1", etc. - keypoint_names : list of str, optional - List of unique names for the keypoints in the skeleton. If None - (default), the keypoints will be named "keypoint_0", "keypoint_1", - etc. - fps : float, optional - Frames per second of the video. Defaults to None. - source_software : str, optional - Name of the software from which the poses were loaded. - Defaults to None. 
- - """ - - # Define class attributes - position_array: np.ndarray = field() - confidence_array: Optional[np.ndarray] = field(default=None) - individual_names: Optional[list[str]] = field( - default=None, - converter=converters.optional(_list_of_str), - ) - keypoint_names: Optional[list[str]] = field( - default=None, - converter=converters.optional(_list_of_str), - ) - fps: Optional[float] = field( - default=None, - converter=converters.pipe( # type: ignore - converters.optional(float), _set_fps_to_none_if_invalid - ), - ) - source_software: Optional[str] = field( - default=None, - validator=validators.optional(validators.instance_of(str)), - ) - - # Add validators - @position_array.validator - def _validate_position_array(self, attribute, value): - _ensure_type_ndarray(value) - if value.ndim != 4: - raise log_error( - ValueError, - f"Expected `{attribute}` to have 4 dimensions, " - f"but got {value.ndim}.", - ) - if value.shape[-1] not in [2, 3]: - raise log_error( - ValueError, - f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " - f"but got {value.shape[-1]}.", - ) - - @confidence_array.validator - def _validate_confidence_array(self, attribute, value): - if value is not None: - _ensure_type_ndarray(value) - expected_shape = self.position_array.shape[:-1] - if value.shape != expected_shape: - raise log_error( - ValueError, - f"Expected `{attribute}` to have shape " - f"{expected_shape}, but got {value.shape}.", - ) - - @individual_names.validator - def _validate_individual_names(self, attribute, value): - if self.source_software == "LightningPose": - # LightningPose only supports a single individual - _validate_list_length(attribute, value, 1) - else: - _validate_list_length( - attribute, value, self.position_array.shape[1] - ) - - @keypoint_names.validator - def _validate_keypoint_names(self, attribute, value): - _validate_list_length(attribute, value, self.position_array.shape[2]) - - def __attrs_post_init__(self): - """Assign default values to optional attributes (if None).""" - if self.confidence_array is None: - self.confidence_array = np.full( - (self.position_array.shape[:-1]), np.nan, dtype="float32" - ) - log_warning( - "Confidence array was not provided." - "Setting to an array of NaNs." - ) - if self.individual_names is None: - self.individual_names = [ - f"individual_{i}" for i in range(self.position_array.shape[1]) - ] - log_warning( - "Individual names were not provided. " - f"Setting to {self.individual_names}." - ) - if self.keypoint_names is None: - self.keypoint_names = [ - f"keypoint_{i}" for i in range(self.position_array.shape[2]) - ] - log_warning( - "Keypoint names were not provided. " - f"Setting to {self.keypoint_names}." - ) diff --git a/movement/kinematics.py b/movement/kinematics.py new file mode 100644 index 00000000..12e1514f --- /dev/null +++ b/movement/kinematics.py @@ -0,0 +1,863 @@ +"""Compute kinematic variables like velocity and acceleration.""" + +import itertools +from typing import Literal + +import numpy as np +import xarray as xr +from scipy.spatial.distance import cdist + +from movement.utils.logging import log_error, log_warning +from movement.utils.reports import report_nan_values +from movement.utils.vector import compute_norm +from movement.validators.arrays import validate_dims_coords + + +def compute_displacement(data: xr.DataArray) -> xr.DataArray: + """Compute displacement array in cartesian coordinates. 
+ + The displacement array is defined as the difference between the position + array at time point ``t`` and the position array at time point ``t-1``. + + As a result, for a given individual and keypoint, the displacement vector + at time point ``t``, is the vector pointing from the previous + ``(t-1)`` to the current ``(t)`` position, in cartesian coordinates. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing displacement vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``displacement`` + array will hold the displacement vectors for every keypoint and every + individual. + + For the ``position`` array of a ``bboxes`` dataset, the ``displacement`` + array will hold the displacement vectors for the centroid of every + individual bounding box. + + For the ``shape`` array of a ``bboxes`` dataset, the + ``displacement`` array will hold vectors with the change in width and + height per bounding box, between consecutive time points. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + result = data.diff(dim="time") + result = result.reindex(data.coords, fill_value=0) + return result + + +def compute_velocity(data: xr.DataArray) -> xr.DataArray: + """Compute velocity array in cartesian coordinates. + + The velocity array is the first time-derivative of the position + array. It is computed by applying the second-order accurate central + differences method on the position array. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing velocity vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``velocity`` array + will hold the velocity vectors for every keypoint and every individual. + + For the ``position`` array of a ``bboxes`` dataset, the ``velocity`` array + will hold the velocity vectors for the centroid of every individual + bounding box. + + See Also + -------- + compute_time_derivative : The underlying function used. + + """ + # validate only presence of Cartesian space dimension + # (presence of time dimension will be checked in compute_time_derivative) + validate_dims_coords(data, {"space": []}) + return compute_time_derivative(data, order=1) + + +def compute_acceleration(data: xr.DataArray) -> xr.DataArray: + """Compute acceleration array in cartesian coordinates. + + The acceleration array is the second time-derivative of the + position array. It is computed by applying the second-order accurate + central differences method on the velocity array. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing acceleration vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``acceleration`` + array will hold the acceleration vectors for every keypoint and every + individual. 
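A quick way to sanity-check these definitions is to apply them to a synthetic trajectory whose derivatives are known. The minimal DataArray below, with only ``time`` and ``space`` dimensions and a made-up 10 Hz sampling rate, is sufficient, since the kinematic functions only require those two dimensions.

```python
import numpy as np
import xarray as xr

from movement.kinematics import (
    compute_acceleration,
    compute_displacement,
    compute_speed,
    compute_velocity,
)

# Synthetic 2D trajectory sampled at 10 Hz: x(t) = t, y(t) = t**2.
t = np.arange(0, 2, 0.1)
position = xr.DataArray(
    np.column_stack([t, t**2]),
    dims=("time", "space"),
    coords={"time": t, "space": ["x", "y"]},
)

displacement = compute_displacement(position)  # zero at the first time point
velocity = compute_velocity(position)  # approx. (1, 2t)
acceleration = compute_acceleration(position)  # approx. (0, 2)
speed = compute_speed(position)  # approx. sqrt(1 + 4 * t**2), no "space" dim
```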
+ + For the ``position`` array of a ``bboxes`` dataset, the ``acceleration`` + array will hold the acceleration vectors for the centroid of every + individual bounding box. + + See Also + -------- + compute_time_derivative : The underlying function used. + + """ + # validate only presence of Cartesian space dimension + # (presence of time dimension will be checked in compute_time_derivative) + validate_dims_coords(data, {"space": []}) + return compute_time_derivative(data, order=2) + + +def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray: + """Compute the time-derivative of an array using numerical differentiation. + + This function uses :meth:`xarray.DataArray.differentiate`, + which differentiates the array with the second-order + accurate central differences method. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``time`` as a required dimension. + order : int + The order of the time-derivative. For an input containing position + data, use 1 to compute velocity, and 2 to compute acceleration. Value + must be a positive integer. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the time-derivative of the input data. + + See Also + -------- + :meth:`xarray.DataArray.differentiate` : The underlying method used. + + """ + if not isinstance(order, int): + raise log_error( + TypeError, f"Order must be an integer, but got {type(order)}." + ) + if order <= 0: + raise log_error(ValueError, "Order must be a positive integer.") + validate_dims_coords(data, {"time": []}) + result = data + for _ in range(order): + result = result.differentiate("time") + return result + + +def compute_speed(data: xr.DataArray) -> xr.DataArray: + """Compute instantaneous speed at each time point. + + Speed is a scalar quantity computed as the Euclidean norm (magnitude) + of the velocity vector at each time point. + + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed speed, + with dimensions matching those of the input data, + except ``space`` is removed. + + """ + return compute_norm(compute_velocity(data)) + + +def compute_forward_vector( + data: xr.DataArray, + left_keypoint: str, + right_keypoint: str, + camera_view: Literal["top_down", "bottom_up"] = "top_down", +): + """Compute a 2D forward vector given two left-right symmetric keypoints. + + The forward vector is computed as a vector perpendicular to the + line connecting two symmetrical keypoints on either side of the body + (i.e., symmetrical relative to the mid-sagittal plane), and pointing + forwards (in the rostral direction). A top-down or bottom-up view of the + animal is assumed (see Notes). + + Parameters + ---------- + data : xarray.DataArray + The input data representing position. This must contain + the two symmetrical keypoints located on the left and + right sides of the body, respectively. + left_keypoint : str + Name of the left keypoint, e.g., "left_ear" + right_keypoint : str + Name of the right keypoint, e.g., "right_ear" + camera_view : Literal["top_down", "bottom_up"], optional + The camera viewing angle, used to determine the upwards + direction of the animal. Can be either ``"top_down"`` (where the + upwards direction is [0, 0, -1]), or ``"bottom_up"`` (where the + upwards direction is [0, 0, 1]). 
If left unspecified, the camera + view is assumed to be ``"top_down"``. + + Returns + ------- + xarray.DataArray + An xarray DataArray representing the forward vector, with + dimensions matching the input data array, but without the + ``keypoints`` dimension. + + Notes + ----- + To determine the forward direction of the animal, we need to specify + (1) the right-to-left direction of the animal and (2) its upward direction. + We determine the right-to-left direction via the input left and right + keypoints. The upwards direction, in turn, can be determined by passing the + ``camera_view`` argument with either ``"top_down"`` or ``"bottom_up"``. If + the camera view is specified as being ``"top_down"``, or if no additional + information is provided, we assume that the upwards direction matches that + of the vector ``[0, 0, -1]``. If the camera view is ``"bottom_up"``, the + upwards direction is assumed to be given by ``[0, 0, 1]``. For both cases, + we assume that position values are expressed in the image coordinate + system (where the positive X-axis is oriented to the right, the positive + Y-axis faces downwards, and positive Z-axis faces away from the person + viewing the screen). + + If one of the required pieces of information is missing for a frame (e.g., + the left keypoint is not visible), then the computed head direction vector + is set to NaN. + + """ + # Validate input data + _validate_type_data_array(data) + validate_dims_coords( + data, + { + "time": [], + "keypoints": [left_keypoint, right_keypoint], + "space": [], + }, + ) + if len(data.space) != 2: + raise log_error( + ValueError, + "Input data must have exactly 2 spatial dimensions, but " + f"currently has {len(data.space)}.", + ) + + # Validate input keypoints + if left_keypoint == right_keypoint: + raise log_error( + ValueError, "The left and right keypoints may not be identical." + ) + + # Define right-to-left vector + right_to_left_vector = data.sel( + keypoints=left_keypoint, drop=True + ) - data.sel(keypoints=right_keypoint, drop=True) + + # Define upward vector + # default: negative z direction in the image coordinate system + if camera_view == "top_down": + upward_vector = np.array([0, 0, -1]) + else: + upward_vector = np.array([0, 0, 1]) + + upward_vector = xr.DataArray( + np.tile(upward_vector.reshape(1, -1), [len(data.time), 1]), + dims=["time", "space"], + ) + + # Compute forward direction as the cross product + # (right-to-left) cross (forward) = up + forward_vector = xr.cross( + right_to_left_vector, upward_vector, dim="space" + )[:, :, :-1] # keep only the first 2 dimensions of the result + + # Return unit vector + + return forward_vector / compute_norm(forward_vector) + + +def compute_head_direction_vector( + data: xr.DataArray, + left_keypoint: str, + right_keypoint: str, + camera_view: Literal["top_down", "bottom_up"] = "top_down", +): + """Compute the 2D head direction vector given two keypoints on the head. + + This function is an alias for :func:`compute_forward_vector()\ + <movement.kinematics.compute_forward_vector>`. For more + detailed information on how the head direction vector is computed, + please refer to the documentation for that function. + + Parameters + ---------- + data : xarray.DataArray + The input data representing position. This must contain + the two chosen keypoints corresponding to the left and + right of the head.
+ left_keypoint : str + Name of the left keypoint, e.g., "left_ear" + right_keypoint : str + Name of the right keypoint, e.g., "right_ear" + camera_view : Literal["top_down", "bottom_up"], optional + The camera viewing angle, used to determine the upwards + direction of the animal. Can be either ``"top_down"`` (where the + upwards direction is [0, 0, -1]), or ``"bottom_up"`` (where the + upwards direction is [0, 0, 1]). If left unspecified, the camera + view is assumed to be ``"top_down"``. + + Returns + ------- + xarray.DataArray + An xarray DataArray representing the head direction vector, with + dimensions matching the input data array, but without the + ``keypoints`` dimension. + + """ + return compute_forward_vector( + data, left_keypoint, right_keypoint, camera_view=camera_view + ) + + +def _cdist( + a: xr.DataArray, + b: xr.DataArray, + dim: Literal["individuals", "keypoints"], + metric: str | None = "euclidean", + **kwargs, +) -> xr.DataArray: + """Compute distances between two position arrays across a given dimension. + + This function is a wrapper around :func:`scipy.spatial.distance.cdist` + and computes the pairwise distances between the two input position arrays + across the dimension specified by ``dim``. + The dimension can be either ``individuals`` or ``keypoints``. + The distances are computed using the specified ``metric``. + + Parameters + ---------- + a : xarray.DataArray + The first input data containing position information of a + single individual or keypoint, with ``time``, ``space`` + (in Cartesian coordinates), and ``individuals`` or ``keypoints`` + (as specified by ``dim``) as required dimensions. + b : xarray.DataArray + The second input data containing position information of a + single individual or keypoint, with ``time``, ``space`` + (in Cartesian coordinates), and ``individuals`` or ``keypoints`` + (as specified by ``dim``) as required dimensions. + dim : str + The dimension to compute the distances for. Must be either + ``'individuals'`` or ``'keypoints'``. + metric : str, optional + The distance metric to use. Must be one of the options supported + by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``, + ``'euclidean'``, etc. + Defaults to ``'euclidean'``. + **kwargs : dict + Additional keyword arguments to pass to + :func:`scipy.spatial.distance.cdist`. + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed distances between + each pair of inputs. + + Examples + -------- + Compute the Euclidean distance (default) between ``ind1`` and + ``ind2`` (i.e. interindividual distance for all keypoints) + using the ``position`` data variable in the Dataset ``ds``: + + >>> pos1 = ds.position.sel(individuals="ind1") + >>> pos2 = ds.position.sel(individuals="ind2") + >>> ind_dists = _cdist(pos1, pos2, dim="individuals") + + Compute the Euclidean distance (default) between ``key1`` and + ``key2`` (i.e. interkeypoint distance for all individuals) + using the ``position`` data variable in the Dataset ``ds``: + + >>> pos1 = ds.position.sel(keypoints="key1") + >>> pos2 = ds.position.sel(keypoints="key2") + >>> key_dists = _cdist(pos1, pos2, dim="keypoints") + + See Also + -------- + scipy.spatial.distance.cdist : The underlying function used. 
+ compute_pairwise_distances : Compute pairwise distances between + ``individuals`` or ``keypoints`` + + """ + # The dimension from which ``dim`` labels are obtained + labels_dim = "individuals" if dim == "keypoints" else "keypoints" + elem1 = getattr(a, dim).item() + elem2 = getattr(b, dim).item() + a = _validate_labels_dimension(a, labels_dim) + b = _validate_labels_dimension(b, labels_dim) + result = xr.apply_ufunc( + cdist, + a, + b, + kwargs={"metric": metric, **kwargs}, + input_core_dims=[[labels_dim, "space"], [labels_dim, "space"]], + output_core_dims=[[elem1, elem2]], + vectorize=True, + ) + result = result.assign_coords( + { + elem1: getattr(a, labels_dim).values, + elem2: getattr(a, labels_dim).values, + } + ) + # Drop any squeezed coordinates + return result.squeeze(drop=True) + + +def compute_pairwise_distances( + data: xr.DataArray, + dim: Literal["individuals", "keypoints"], + pairs: dict[str, str | list[str]] | Literal["all"], + metric: str | None = "euclidean", + **kwargs, +) -> xr.DataArray | dict[str, xr.DataArray]: + """Compute pairwise distances between ``individuals`` or ``keypoints``. + + This function computes the distances between + pairs of ``individuals`` (i.e. interindividual distances) or + pairs of ``keypoints`` (i.e. interkeypoint distances), + as determined by ``dim``. + The distances are computed for the given ``pairs`` + using the specified ``metric``. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time``, + ``space`` (in Cartesian coordinates), and + ``individuals`` or ``keypoints`` (as specified by ``dim``) + as required dimensions. + dim : Literal["individuals", "keypoints"] + The dimension to compute the distances for. Must be either + ``'individuals'`` or ``'keypoints'``. + pairs : dict[str, str | list[str]] or 'all' + Specifies the pairs of elements (either individuals or keypoints) + for which to compute distances, depending on the value of ``dim``. + + - If ``dim='individuals'``, ``pairs`` should be a dictionary where + each key is an individual name, and each value is also an individual + name or a list of such names to compute distances with. + - If ``dim='keypoints'``, ``pairs`` should be a dictionary where each + key is a keypoint name, and each value is also keypoint name or a + list of such names to compute distances with. + - Alternatively, use the special keyword ``'all'`` to compute distances + for all possible pairs of individuals or keypoints + (depending on ``dim``). + metric : str, optional + The distance metric to use. Must be one of the options supported + by :func:`scipy.spatial.distance.cdist`, e.g. ``'cityblock'``, + ``'euclidean'``, etc. + Defaults to ``'euclidean'``. + **kwargs : dict + Additional keyword arguments to pass to + :func:`scipy.spatial.distance.cdist`. + + Returns + ------- + xarray.DataArray or dict[str, xarray.DataArray] + The computed pairwise distances. If a single pair is specified in + ``pairs``, returns an :class:`xarray.DataArray`. If multiple pairs + are specified, returns a dictionary where each key is a string + representing the pair (e.g., ``'dist_ind1_ind2'`` or + ``'dist_key1_key2'``) and each value is an :class:`xarray.DataArray` + containing the computed distances for that pair. + + Raises + ------ + ValueError + If ``dim`` is not one of ``'individuals'`` or ``'keypoints'``; + if ``pairs`` is not a dictionary or ``'all'``; or + if there are no pairs in ``data`` to compute distances for. 
+ + Examples + -------- + Compute the Euclidean distance (default) between ``ind1`` and ``ind2`` + (i.e. interindividual distance), for all possible pairs of keypoints. + + >>> position = xr.DataArray( + ... np.arange(36).reshape(2, 3, 3, 2), + ... coords={ + ... "time": np.arange(2), + ... "individuals": ["ind1", "ind2", "ind3"], + ... "keypoints": ["key1", "key2", "key3"], + ... "space": ["x", "y"], + ... }, + ... dims=["time", "individuals", "keypoints", "space"], + ... ) + >>> dist_ind1_ind2 = compute_pairwise_distances( + ... position, "individuals", {"ind1": "ind2"} + ... ) + >>> dist_ind1_ind2 + <xarray.DataArray (time: 2, ind1: 3, ind2: 3)> Size: 144B + 8.485 11.31 14.14 5.657 8.485 11.31 ... 5.657 8.485 11.31 2.828 5.657 8.485 + Coordinates: + * time (time) int64 16B 0 1 + * ind1 (ind1) <U4 48B 'key1' 'key2' 'key3' + * ind2 (ind2) <U4 48B 'key1' 'key2' 'key3' + + To obtain the distances between ``key1`` of ``ind1`` and + ``key2`` of ``ind2``: + + >>> dist_ind1_ind2.sel(ind1="key1", ind2="key2") + + Compute the Euclidean distance (default) between ``key1`` and ``key2`` + (i.e. interkeypoint distance), for all possible pairs of individuals. + + >>> dist_key1_key2 = compute_pairwise_distances( + ... position, "keypoints", {"key1": "key2"} + ... ) + >>> dist_key1_key2 + <xarray.DataArray (time: 2, key1: 3, key2: 3)> Size: 144B + 2.828 11.31 19.8 5.657 2.828 11.31 14.14 ... 2.828 11.31 14.14 5.657 2.828 + Coordinates: + * time (time) int64 16B 0 1 + * key1 (key1) <U4 48B 'ind1' 'ind2' 'ind3' + * key2 (key2) <U4 48B 'ind1' 'ind2' 'ind3' + + To obtain the distances between ``key1`` and ``key2`` within ``ind1``: + + >>> dist_key1_key2.sel(key1="ind1", key2="ind1") + + To obtain the distances between ``key1`` of ``ind1`` and + ``key2`` of ``ind2``: + + >>> dist_key1_key2.sel(key1="ind1", key2="ind2") + + Compute the city block or Manhattan distance for multiple pairs of + keypoints using ``position``: + + >>> key_dists = compute_pairwise_distances( + ... position, + ... "keypoints", + ... {"key1": "key2", "key3": ["key1", "key2"]}, + ... metric="cityblock", + ... ) + >>> key_dists.keys() + dict_keys(['dist_key1_key2', 'dist_key3_key1', 'dist_key3_key2']) + + As multiple pairs of keypoints are specified, + the resulting ``key_dists`` is a dictionary containing the DataArrays + of computed distances for each pair of keypoints. + + Compute the city block or Manhattan distance for all possible pairs of + individuals using ``position``: + + >>> ind_dists = compute_pairwise_distances( + ... position, + ... "individuals", + ... "all", + ... metric="cityblock", + ... ) + >>> ind_dists.keys() + dict_keys(['dist_ind1_ind2', 'dist_ind1_ind3', 'dist_ind2_ind3']) + + See Also + -------- + scipy.spatial.distance.cdist : The underlying function used. + + """ + if dim not in ["individuals", "keypoints"]: + raise log_error( + ValueError, + "'dim' must be either 'individuals' or 'keypoints', " + f"but got {dim}.", + ) + if isinstance(pairs, str) and pairs != "all": + raise log_error( + ValueError, + f"'pairs' must be a dictionary or 'all', but got {pairs}.", + ) + validate_dims_coords(data, {"time": [], "space": ["x", "y"], dim: []}) + # Find all possible pair combinations if 'all' is specified + if pairs == "all": + paired_elements = list( + itertools.combinations(getattr(data, dim).values, 2) + ) + else: + paired_elements = [ + (elem1, elem2) + for elem1, elem2_list in pairs.items() + for elem2 in + ( + # Ensure elem2_list is a list + [elem2_list] if isinstance(elem2_list, str) else elem2_list + ) + ] + if not paired_elements: + raise log_error( + ValueError, "Could not find any pairs to compute distances for."
+ ) + pairwise_distances = { + f"dist_{elem1}_{elem2}": _cdist( + data.sel({dim: elem1}), + data.sel({dim: elem2}), + dim=dim, + metric=metric, + **kwargs, + ) + for elem1, elem2 in paired_elements + } + # Return DataArray if result only has one key + if len(pairwise_distances) == 1: + return next(iter(pairwise_distances.values())) + return pairwise_distances + + +def _validate_labels_dimension(data: xr.DataArray, dim: str) -> xr.DataArray: + """Validate the input data contains the ``dim`` for labelling dimensions. + + This function ensures the input data contains the ``dim`` + used as labels (coordinates) when applying + :func:`scipy.spatial.distance.cdist` to + the input data, by adding a temporary dimension if necessary. + + Parameters + ---------- + data : xarray.DataArray + The input data to validate. + dim : str + The dimension to validate. + + Returns + ------- + xarray.DataArray + The input data with the labels dimension validated. + + """ + if data.coords.get(dim) is None: + data = data.assign_coords({dim: "temp_dim"}) + if data.coords[dim].ndim == 0: + data = data.expand_dims(dim).transpose("time", "space", dim) + return data + + +def _validate_type_data_array(data: xr.DataArray) -> None: + """Validate the input data is an xarray DataArray. + + Parameters + ---------- + data : xarray.DataArray + The input data to validate. + + Raises + ------ + ValueError + If the input data is not an xarray DataArray. + + """ + if not isinstance(data, xr.DataArray): + raise log_error( + TypeError, + f"Input data must be an xarray.DataArray, but got {type(data)}.", + ) + + +def compute_path_length( + data: xr.DataArray, + start: float | None = None, + stop: float | None = None, + nan_policy: Literal["ffill", "scale"] = "ffill", + nan_warn_threshold: float = 0.2, +) -> xr.DataArray: + """Compute the length of a path travelled between two time points. + + The path length is defined as the sum of the norms (magnitudes) of the + displacement vectors between two time points ``start`` and ``stop``, + which should be provided in the time units of the data array. + If not specified, the minimum and maximum time coordinates of the data + array are used as start and stop times, respectively. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + start : float, optional + The start time of the path. If None (default), + the minimum time coordinate in the data is used. + stop : float, optional + The end time of the path. If None (default), + the maximum time coordinate in the data is used. + nan_policy : Literal["ffill", "scale"], optional + Policy to handle NaN (missing) values. Can be one of the ``"ffill"`` + or ``"scale"``. Defaults to ``"ffill"`` (forward fill). + See Notes for more details on the two policies. + nan_warn_threshold : float, optional + If more than this proportion of values are missing in any point track, + a warning will be emitted. Defaults to 0.2 (20%). + + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed path length, + with dimensions matching those of the input data, + except ``time`` and ``space`` are removed. + + Notes + ----- + Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill` + to forward-fill missing segments (NaN values) across time. 
+ This equates to assuming that a track remains stationary for + the duration of the missing segment and then instantaneously moves to + the next valid position, following a straight line. This approach tends + to underestimate the path length, and the error increases with the number + of missing values. + + Choosing ``nan_policy="scale"`` will adjust the path length based on the + the proportion of valid segments per point track. For example, if only + 80% of segments are present, the path length will be computed based on + these and the result will be divided by 0.8. This approach assumes + that motion dynamics are similar across observed and missing time + segments, which may not accurately reflect actual conditions. + + """ + validate_dims_coords(data, {"time": [], "space": []}) + data = data.sel(time=slice(start, stop)) + # Check that the data is not empty or too short + n_time = data.sizes["time"] + if n_time < 2: + raise log_error( + ValueError, + f"At least 2 time points are required to compute path length, " + f"but {n_time} were found. Double-check the start and stop times.", + ) + + _warn_about_nan_proportion(data, nan_warn_threshold) + + if nan_policy == "ffill": + return compute_norm( + compute_displacement(data.ffill(dim="time")).isel( + time=slice(1, None) + ) # skip first displacement (always 0) + ).sum(dim="time", min_count=1) # return NaN if no valid segment + + elif nan_policy == "scale": + return _compute_scaled_path_length(data) + else: + raise log_error( + ValueError, + f"Invalid value for nan_policy: {nan_policy}. " + "Must be one of 'ffill' or 'scale'.", + ) + + +def _warn_about_nan_proportion( + data: xr.DataArray, nan_warn_threshold: float +) -> None: + """Print a warning if the proportion of NaN values exceeds a threshold. + + The NaN proportion is evaluated per point track, and a given point is + considered NaN if any of its ``space`` coordinates are NaN. The warning + specifically lists the point tracks that exceed the threshold. + + Parameters + ---------- + data : xarray.DataArray + The input data array. + nan_warn_threshold : float + The threshold for the proportion of NaN values. Must be a number + between 0 and 1. + + """ + nan_warn_threshold = float(nan_warn_threshold) + if not 0 <= nan_warn_threshold <= 1: + raise log_error( + ValueError, + "nan_warn_threshold must be between 0 and 1.", + ) + + n_nans = data.isnull().any(dim="space").sum(dim="time") + data_to_warn_about = data.where( + n_nans > data.sizes["time"] * nan_warn_threshold, drop=True + ) + if len(data_to_warn_about) > 0: + log_warning( + "The result may be unreliable for point tracks with many " + "missing values. The following tracks have more than " + f"{nan_warn_threshold * 100:.3} % NaN values:", + ) + print(report_nan_values(data_to_warn_about)) + + +def _compute_scaled_path_length( + data: xr.DataArray, +) -> xr.DataArray: + """Compute scaled path length based on proportion of valid segments. + + Path length is first computed based on valid segments (non-NaN values + on both ends of the segment) and then scaled based on the proportion of + valid segments per point track - i.e. the result is divided by the + proportion of valid segments. + + Parameters + ---------- + data : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. 
+ + Returns + ------- + xarray.DataArray + An xarray DataArray containing the computed path length, + with dimensions matching those of the input data, + except ``time`` and ``space`` are removed. + + """ + # Skip first displacement segment (always 0) to not mess up the scaling + displacement = compute_displacement(data).isel(time=slice(1, None)) + # count number of valid displacement segments per point track + valid_segments = (~displacement.isnull()).all(dim="space").sum(dim="time") + # compute proportion of valid segments per point track + valid_proportion = valid_segments / (data.sizes["time"] - 1) + # return scaled path length + return compute_norm(displacement).sum(dim="time") / valid_proportion diff --git a/movement/move_accessor.py b/movement/move_accessor.py deleted file mode 100644 index 17c268dc..00000000 --- a/movement/move_accessor.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Accessor for extending xarray.Dataset objects.""" - -import logging -from typing import ClassVar - -import xarray as xr - -from movement.analysis import kinematics -from movement.io.validators import ValidPosesDataset - -logger = logging.getLogger(__name__) - -# Preserve the attributes (metadata) of xarray objects after operations -xr.set_options(keep_attrs=True) - - -@xr.register_dataset_accessor("move") -class MovementDataset: - """An :py:class:`xarray.Dataset` accessor for pose tracking data. - - A ``movement`` dataset is an :py:class:`xarray.Dataset` with a specific - structure to represent pose tracks, associated confidence scores and - relevant metadata. - - Methods/properties that extend the standard ``xarray`` functionality are - defined in this class. To avoid conflicts with ``xarray``'s namespace, - ``movement``-specific methods are accessed using the ``move`` keyword, - for example ``ds.move.validate()`` (see [1]_ for more details). - - - References - ---------- - .. [1] https://docs.xarray.dev/en/stable/internals/extending-xarray.html - - """ - - # Names of the expected dimensions in the dataset - dim_names: ClassVar[tuple] = ( - "time", - "individuals", - "keypoints", - "space", - ) - - # Names of the expected data variables in the dataset - var_names: ClassVar[tuple] = ( - "position", - "confidence", - ) - - def __init__(self, ds: xr.Dataset): - """Initialize the MovementDataset.""" - self._obj = ds - - def __getattr__(self, name: str) -> xr.DataArray: - """Forward requested but undefined attributes to relevant modules. - - This method currently only forwards kinematic property computation - to the respective functions in the ``kinematics`` module. - - Parameters - ---------- - name : str - The name of the attribute to get. - - Returns - ------- - xarray.DataArray - The computed attribute value. - - Raises - ------ - AttributeError - If the attribute does not exist. - - """ - - def method(*args, **kwargs): - if name.startswith("compute_") and hasattr(kinematics, name): - self.validate() - return getattr(kinematics, name)( - self._obj.position, *args, **kwargs - ) - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - return method - - def validate(self) -> None: - """Validate the dataset. - - This method checks if the dataset contains the expected dimensions, - data variables, and metadata attributes. It also ensures that the - dataset contains valid poses. 
- """ - fps = self._obj.attrs.get("fps", None) - source_software = self._obj.attrs.get("source_software", None) - try: - missing_dims = set(self.dim_names) - set(self._obj.dims) - missing_vars = set(self.var_names) - set(self._obj.data_vars) - if missing_dims: - raise ValueError( - f"Missing required dimensions: {missing_dims}" - ) - if missing_vars: - raise ValueError( - f"Missing required data variables: {missing_vars}" - ) - ValidPosesDataset( - position_array=self._obj[self.var_names[0]].values, - confidence_array=self._obj[self.var_names[1]].values, - individual_names=self._obj.coords[self.dim_names[1]].values, - keypoint_names=self._obj.coords[self.dim_names[2]].values, - fps=fps, - source_software=source_software, - ) - except Exception as e: - error_msg = "The dataset does not contain valid poses." - logger.error(error_msg) - raise ValueError(error_msg) from e diff --git a/movement/sample_data.py b/movement/sample_data.py index 29875845..67a228cf 100644 --- a/movement/sample_data.py +++ b/movement/sample_data.py @@ -1,4 +1,4 @@ -"""Module for fetching and loading sample datasets. +"""Fetch and load sample datasets. This module provides functions for fetching and loading sample data used in tests, examples, and tutorials. The data are stored in a remote repository @@ -14,8 +14,8 @@ import yaml from requests.exceptions import RequestException -from movement.io import load_poses -from movement.logging import log_error, log_warning +from movement.io import load_bboxes, load_poses +from movement.utils.logging import log_error, log_warning logger = logging.getLogger(__name__) @@ -87,7 +87,7 @@ def _fetch_metadata( ------- dict A dictionary containing metadata for each sample dataset, with the - dataset name (pose file name) as the key. + dataset file name as the key. """ local_file_path = Path(data_dir / file_name) @@ -116,7 +116,8 @@ def _fetch_metadata( def _generate_file_registry(metadata: dict[str, dict]) -> dict[str, str]: """Generate a file registry based on the contents of the metadata. - This includes files containing poses, frames, or entire videos. + This includes files containing poses, frames, videos, or bounding boxes + data. Parameters ---------- @@ -131,7 +132,7 @@ def _generate_file_registry(metadata: dict[str, dict]) -> dict[str, str]: """ file_registry = {} for ds, val in metadata.items(): - file_registry[f"poses/{ds}"] = val["sha256sum"] + file_registry[f"{val['type']}/{ds}"] = val["sha256sum"] for key in ["video", "frame"]: file_name = val[key]["file_name"] if file_name: @@ -139,7 +140,7 @@ def _generate_file_registry(metadata: dict[str, dict]) -> dict[str, str]: return file_registry -# Create a download manager for the pose data +# Create a download manager for the sample data metadata = _fetch_metadata(METADATA_FILE, DATA_DIR) file_registry = _generate_file_registry(metadata) SAMPLE_DATA = pooch.create( @@ -151,19 +152,19 @@ def _generate_file_registry(metadata: dict[str, dict]) -> dict[str, str]: def list_datasets() -> list[str]: - """Find available sample datasets. + """List available sample datasets. Returns ------- filenames : list of str - List of filenames for available pose data. + List of filenames for available sample datasets. """ return list(metadata.keys()) -def fetch_dataset_paths(filename: str) -> dict: - """Get paths to sample pose data and any associated frames or videos. +def fetch_dataset_paths(filename: str, with_video: bool = False) -> dict: + """Get paths to sample dataset and any associated frames or videos. 
The data are downloaded from the ``movement`` data repository to the user's local machine upon first use and are stored in a local cache directory. @@ -172,64 +173,86 @@ def fetch_dataset_paths(filename: str) -> dict: Parameters ---------- filename : str - Name of the pose file to fetch. + Name of the sample data file to fetch. + with_video : bool, optional + Whether to download the associated video file (if available). If set + to False, the "video" entry in the returned dictionary will be None. + Defaults to False. Returns ------- paths : dict Dictionary mapping file types to their respective paths. The possible - file types are: "poses", "frame", "video". If "frame" or "video" are - not available, the corresponding value is None. + file types are: "poses" or "bboxes" (depending on tracking type), + "frame", "video". A None value for "frame" or "video" indicates that + the file is either not available or not requested + (if ``with_video=False``). Examples -------- + Fetch a sample dataset and get the paths to the file containing the + predicted poses, as well as the associated frame and video files: + >>> from movement.sample_data import fetch_dataset_paths - >>> paths = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5") + >>> paths = fetch_dataset_paths( + ... "DLC_single-mouse_EPM.predictions.h5", with_video=True + ... ) >>> poses_path = paths["poses"] >>> frame_path = paths["frame"] >>> video_path = paths["video"] + If the sample dataset contains bounding boxes instead of + poses, use ``paths["bboxes"]`` instead of ``paths["poses"]``: + + >>> paths = fetch_dataset_paths("VIA_multiple-crabs_5-frames_labels.csv") + >>> bboxes_path = paths["bboxes"] + + See Also -------- fetch_dataset """ - available_pose_files = list_datasets() - if filename not in available_pose_files: + available_data_files = list_datasets() + if filename not in available_data_files: raise log_error( ValueError, f"File '{filename}' is not in the registry. " - f"Valid filenames are: {available_pose_files}", + f"Valid filenames are: {available_data_files}", ) frame_file_name = metadata[filename]["frame"]["file_name"] video_file_name = metadata[filename]["video"]["file_name"] - - return { - "poses": Path( - SAMPLE_DATA.fetch(f"poses/{filename}", progressbar=True) - ), + paths_dict = { "frame": None if not frame_file_name else Path( SAMPLE_DATA.fetch(f"frames/{frame_file_name}", progressbar=True) ), "video": None - if not video_file_name + if (not video_file_name) or not (with_video) else Path( SAMPLE_DATA.fetch(f"videos/{video_file_name}", progressbar=True) ), } + # Add trajectory data + # Assume "poses" if not of type "bboxes" + data_type = "bboxes" if metadata[filename]["type"] == "bboxes" else "poses" + paths_dict[data_type] = Path( + SAMPLE_DATA.fetch(f"{data_type}/{filename}", progressbar=True) + ) + return paths_dict def fetch_dataset( filename: str, + with_video: bool = False, ) -> xarray.Dataset: - """Load a sample dataset containing pose data. + """Load a sample dataset. The data are downloaded from the ``movement`` data repository to the user's local machine upon first use and are stored in a local cache directory. - This function returns the pose data as an xarray Dataset. + This function returns the data as an xarray Dataset. If there are any associated frames or videos, these files are also downloaded and the paths are stored as dataset attributes. @@ -237,16 +260,25 @@ def fetch_dataset( ---------- filename : str Name of the file to fetch. 
+ with_video : bool, optional + Whether to download the associated video file (if available). If set + to False, the "video" entry in the returned dictionary will be None. + Defaults to False. Returns ------- ds : xarray.Dataset - Pose data contained in the fetched sample file. + Data contained in the fetched sample file. Examples -------- + Fetch a sample dataset and get the paths to the associated frame and video + files: + >>> from movement.sample_data import fetch_dataset - >>> ds = fetch_dataset("DLC_single-mouse_EPM.predictions.h5") + >>> ds = fetch_dataset( + "DLC_single-mouse_EPM.predictions.h5", with_video=True + ) >>> frame_path = ds.video_path >>> video_path = ds.frame_path @@ -255,13 +287,18 @@ def fetch_dataset( fetch_dataset_paths """ - file_paths = fetch_dataset_paths(filename) + file_paths = fetch_dataset_paths(filename, with_video=with_video) + + for key, load_module in zip( + ["poses", "bboxes"], [load_poses, load_bboxes], strict=False + ): + if file_paths.get(key): + ds = load_module.from_file( + file_paths[key], + source_software=metadata[filename]["source_software"], + fps=metadata[filename]["fps"], + ) - ds = load_poses.from_file( - file_paths["poses"], - source_software=metadata[filename]["source_software"], - fps=metadata[filename]["fps"], - ) ds.attrs["frame_path"] = file_paths["frame"] ds.attrs["video_path"] = file_paths["video"] diff --git a/movement/utils/__init__.py b/movement/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/movement/logging.py b/movement/utils/logging.py similarity index 75% rename from movement/logging.py rename to movement/utils/logging.py index 9311711d..0174e5ff 100644 --- a/movement/logging.py +++ b/movement/utils/logging.py @@ -1,6 +1,8 @@ """Logging utilities for the movement package.""" import logging +from datetime import datetime +from functools import wraps from logging.handlers import RotatingFileHandler from pathlib import Path @@ -105,3 +107,34 @@ def log_warning(message: str, logger_name: str = "movement"): """ logger = logging.getLogger(logger_name) logger.warning(message) + + +def log_to_attrs(func): + """Log the operation performed by the wrapped function. + + This decorator appends log entries to the data's ``log`` + attribute. The wrapped function must accept an :class:`xarray.Dataset` + or :class:`xarray.DataArray` as its first argument and return an + object of the same type. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + + log_entry = { + "operation": func.__name__, + "datetime": str(datetime.now()), + **{f"arg_{i}": arg for i, arg in enumerate(args[1:], start=1)}, + **kwargs, + } + + # Append the log entry to the result's attributes + if result is not None and hasattr(result, "attrs"): + if "log" not in result.attrs: + result.attrs["log"] = [] + result.attrs["log"].append(log_entry) + + return result + + return wrapper diff --git a/movement/utils/reports.py b/movement/utils/reports.py new file mode 100644 index 00000000..416baec1 --- /dev/null +++ b/movement/utils/reports.py @@ -0,0 +1,95 @@ +"""Utility functions for reporting missing data.""" + +import logging + +import xarray as xr + +logger = logging.getLogger(__name__) + + +def calculate_nan_stats( + data: xr.DataArray, + keypoint: str | None = None, + individual: str | None = None, +) -> str: + """Calculate NaN stats for a given keypoint and individual. + + This function calculates the number and percentage of NaN points + for a given keypoint and individual in the input data. 
A keypoint + is considered NaN if any of its ``space`` coordinates are NaN. + + Parameters + ---------- + data : xarray.DataArray + The input data containing ``keypoints`` and ``individuals`` + dimensions. + keypoint : str, optional + The name of the keypoint for which to generate the report. + If ``None``, it is assumed that the input data contains only + one keypoint and this keypoint is used. + Default is ``None``. + individual : str, optional + The name of the individual for which to generate the report. + If ``None``, it is assumed that the input data contains only + one individual and this individual is used. + Default is ``None``. + + Returns + ------- + str + A string containing the report. + + """ + selection_criteria = {} + if individual is not None: + selection_criteria["individuals"] = individual + if keypoint is not None: + selection_criteria["keypoints"] = keypoint + selected_data = ( + data.sel(**selection_criteria) if selection_criteria else data + ) + n_nans = selected_data.isnull().any(["space"]).sum(["time"]).item() + n_points = selected_data.time.size + percent_nans = round((n_nans / n_points) * 100, 1) + return f"\n\t\t{keypoint}: {n_nans}/{n_points} ({percent_nans}%)" + + +def report_nan_values(da: xr.DataArray, label: str | None = None) -> str: + """Report the number and percentage of keypoints that are NaN. + + Numbers are reported for each individual and keypoint in the data. + + Parameters + ---------- + da : xarray.DataArray + The input data containing ``keypoints`` and ``individuals`` + dimensions. + label : str, optional + Label to identify the data in the report. If not provided, + the name of the DataArray is used as the label. + Default is ``None``. + + Returns + ------- + str + A string containing the report. + + """ + # Compile the report + label = label or da.name + nan_report = f"\nMissing points (marked as NaN) in {label}" + # Check if the data has individuals and keypoints dimensions + has_individuals_dim = "individuals" in da.dims + has_keypoints_dim = "keypoints" in da.dims + # Default values for individuals and keypoints + individuals = da.individuals.values if has_individuals_dim else [None] + keypoints = da.keypoints.values if has_keypoints_dim else [None] + + for ind in individuals: + ind_name = ind if ind is not None else da.individuals.item() + nan_report += f"\n\tIndividual: {ind_name}" + for kp in keypoints: + nan_report += calculate_nan_stats(da, keypoint=kp, individual=ind) + # Write nan report to logger + logger.info(nan_report) + return nan_report diff --git a/movement/utils/vector.py b/movement/utils/vector.py index a30b5658..c91e43ec 100644 --- a/movement/utils/vector.py +++ b/movement/utils/vector.py @@ -3,7 +3,95 @@ import numpy as np import xarray as xr -from movement.logging import log_error +from movement.utils.logging import log_error +from movement.validators.arrays import validate_dims_coords + + +def compute_norm(data: xr.DataArray) -> xr.DataArray: + """Compute the norm of the vectors along the spatial dimension. + + The norm of a vector is its magnitude, also called Euclidean norm, 2-norm + or Euclidean length. Note that if the input data is expressed in polar + coordinates, the magnitude of a vector is the same as its radial coordinate + ``rho``. + + Parameters + ---------- + data : xarray.DataArray + The input data array containing either ``space`` or ``space_pol`` + as a dimension. + + Returns + ------- + xarray.DataArray + A data array holding the norm of the input vectors. 
+ Note that this output array has no spatial dimension but preserves + all other dimensions of the input data array (see Notes). + + Notes + ----- + If the input data array is a ``position`` array, this function will compute + the magnitude of the position vectors, for every individual and keypoint, + at every timestep. If the input data array is a ``shape`` array of a + bounding boxes dataset, it will compute the magnitude of the shape + vectors (i.e., the diagonal of the bounding box), + for every individual and at every timestep. + + + """ + if "space" in data.dims: + validate_dims_coords(data, {"space": ["x", "y"]}) + return xr.apply_ufunc( + np.linalg.norm, + data, + input_core_dims=[["space"]], + kwargs={"axis": -1}, + ) + elif "space_pol" in data.dims: + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + return data.sel(space_pol="rho", drop=True) + else: + _raise_error_for_missing_spatial_dim() + + +def convert_to_unit(data: xr.DataArray) -> xr.DataArray: + """Convert the vectors along the spatial dimension into unit vectors. + + A unit vector is a vector pointing in the same direction as the original + vector but with norm = 1. + + Parameters + ---------- + data : xarray.DataArray + The input data array containing either ``space`` or ``space_pol`` + as a dimension. + + Returns + ------- + xarray.DataArray + A data array holding the unit vectors of the input data array + (all input dimensions are preserved). + + Notes + ----- + Note that the unit vector for the null vector is undefined, since the null + vector has 0 norm and no direction associated with it. + + """ + if "space" in data.dims: + validate_dims_coords(data, {"space": ["x", "y"]}) + return data / compute_norm(data) + elif "space_pol" in data.dims: + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) + # Set both rho and phi values to NaN at null vectors (where rho = 0) + new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data) + # Set the rho values to 1 for non-null vectors (phi is preserved) + new_data.loc[{"space_pol": "rho"}] = xr.where( + new_data.sel(space_pol="rho").isnull(), np.nan, 1 + ) + return new_data + else: + _raise_error_for_missing_spatial_dim() def cart2pol(data: xr.DataArray) -> xr.DataArray: @@ -24,13 +112,8 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: ``phi`` returned are in radians, in the range ``[-pi, pi]``. """ - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) - rho = xr.apply_ufunc( - np.linalg.norm, - data, - input_core_dims=[["space"]], - kwargs={"axis": -1}, - ) + validate_dims_coords(data, {"space": ["x", "y"]}) + rho = compute_norm(data) phi = xr.apply_ufunc( np.arctan2, data.sel(space="y"), @@ -65,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: in the dimension coordinate. """ - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) rho = data.sel(space_pol="rho") phi = data.sel(space_pol="phi") x = rho * np.cos(phi) @@ -82,43 +165,9 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: ).transpose(*dims) -def _validate_dimension_coordinates( - data: xr.DataArray, required_dim_coords: dict -) -> None: - """Validate the input data array. - - Ensure that it contains the required dimensions and coordinates. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - required_dim_coords : dict - A dictionary of required dimensions and their corresponding - coordinate values. 
- - Raises - ------ - ValueError - If the input data does not contain the required dimension(s) - and/or the required coordinate(s). - - """ - missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] - error_message = "" - if missing_dims: - error_message += ( - f"Input data must contain {missing_dims} as dimensions.\n" - ) - missing_coords = [] - for dim, coords in required_dim_coords.items(): - missing_coords = [ - coord for coord in coords if coord not in data.coords.get(dim, []) - ] - if missing_coords: - error_message += ( - "Input data must contain " - f"{missing_coords} in the '{dim}' coordinates." - ) - if error_message: - raise log_error(ValueError, error_message) +def _raise_error_for_missing_spatial_dim() -> None: + raise log_error( + ValueError, + "Input data array must contain either 'space' or 'space_pol' " + "as dimensions.", + ) diff --git a/movement/validators/__init__.py b/movement/validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/movement/validators/arrays.py b/movement/validators/arrays.py new file mode 100644 index 00000000..76847571 --- /dev/null +++ b/movement/validators/arrays.py @@ -0,0 +1,61 @@ +"""Validators for data arrays.""" + +import xarray as xr + +from movement.utils.logging import log_error + + +def validate_dims_coords( + data: xr.DataArray, required_dim_coords: dict +) -> None: + """Validate dimensions and coordinates in a data array. + + This function raises a ValueError if the specified dimensions and + coordinates are not present in the input data array. + + Parameters + ---------- + data : xarray.DataArray + The input data array to validate. + required_dim_coords : dict + A dictionary of required dimensions and their corresponding + coordinate values. If you don't need to specify any + coordinate values, you can pass an empty list. + + Examples + -------- + Validate that a data array contains the dimension 'time'. No specific + coordinates are required. + + >>> validate_dims_coords(data, {"time": []}) + + Validate that a data array contains the dimensions 'time' and 'space', + and that the 'space' dimension contains the coordinates 'x' and 'y'. + + >>> validate_dims_coords(data, {"time": [], "space": ["x", "y"]}) + + Raises + ------ + ValueError + If the input data does not contain the required dimension(s) + and/or the required coordinate(s). + + """ + missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] + error_message = "" + if missing_dims: + error_message += ( + f"Input data must contain {missing_dims} as dimensions.\n" + ) + missing_coords = [] + for dim, coords in required_dim_coords.items(): + missing_coords = [ + coord for coord in coords if coord not in data.coords.get(dim, []) + ] + if missing_coords: + error_message += ( + "Input data must contain " + f"{missing_coords} in the '{dim}' coordinates." 
+ ) + if error_message: + raise log_error(ValueError, error_message) diff --git a/movement/validators/datasets.py b/movement/validators/datasets.py new file mode 100644 index 00000000..99a68c10 --- /dev/null +++ b/movement/validators/datasets.py @@ -0,0 +1,396 @@ +"""``attrs`` classes for validating data structures.""" + +from collections.abc import Iterable +from typing import Any, ClassVar + +import attrs +import numpy as np +from attrs import converters, define, field, validators + +from movement.utils.logging import log_error, log_warning + + +def _convert_to_list_of_str(value: str | Iterable[Any]) -> list[str]: + """Try to coerce the value into a list of strings.""" + if isinstance(value, str): + log_warning( + f"Invalid value ({value}). Expected a list of strings. " + "Converting to a list of length 1." + ) + return [value] + elif isinstance(value, Iterable): + return [str(item) for item in value] + else: + raise log_error( + ValueError, f"Invalid value ({value}). Expected a list of strings." + ) + + +def _convert_fps_to_none_if_invalid(fps: float | None) -> float | None: + """Set fps to None if a non-positive float is passed.""" + if fps is not None and fps <= 0: + log_warning( + f"Invalid fps value ({fps}). Expected a positive number. " + "Setting fps to None." + ) + return None + return fps + + +def _validate_type_ndarray(value: Any) -> None: + """Raise ValueError the value is a not numpy array.""" + if not isinstance(value, np.ndarray): + raise log_error( + ValueError, f"Expected a numpy array, but got {type(value)}." + ) + + +def _validate_array_shape( + attribute: attrs.Attribute, value: np.ndarray, expected_shape: tuple +): + """Raise ValueError if the value does not have the expected shape.""" + if value.shape != expected_shape: + raise log_error( + ValueError, + f"Expected '{attribute.name}' to have shape {expected_shape}, " + f"but got {value.shape}.", + ) + + +def _validate_list_length( + attribute: attrs.Attribute, value: list | None, expected_length: int +): + """Raise a ValueError if the list does not have the expected length.""" + if (value is not None) and (len(value) != expected_length): + raise log_error( + ValueError, + f"Expected '{attribute.name}' to have length {expected_length}, " + f"but got {len(value)}.", + ) + + +@define(kw_only=True) +class ValidPosesDataset: + """Class for validating poses data intended for a ``movement`` dataset. + + The validator ensures that within the ``movement poses`` dataset: + + - The required ``position_array`` is a numpy array + with the last dimension containing 2 or 3 spatial coordinates. + - The optional ``confidence_array``, if provided, is a numpy array + with its shape matching the first three dimensions of the + ``position_array``; otherwise, it defaults to an array of NaNs. + - The optional ``individual_names`` and ``keypoint_names``, + if provided, match the number of individuals and keypoints + in the dataset, respectively; otherwise, default names are assigned. + - The optional ``fps`` is a positive float; otherwise, it defaults to None. + - The optional ``source_software`` is a string; otherwise, + it defaults to None. + + Attributes + ---------- + position_array : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the poses. + confidence_array : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. + If None (default), the scores will be set to an array of NaNs. 
+ individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + Frames per second of the video. Defaults to None. + source_software : str, optional + Name of the software from which the poses were loaded. + Defaults to None. + + Raises + ------ + ValueError + If the dataset does not meet the ``movement poses`` + dataset requirements. + + """ + + # Required attributes + position_array: np.ndarray = field() + + # Optional attributes + confidence_array: np.ndarray | None = field(default=None) + individual_names: list[str] | None = field( + default=None, + converter=converters.optional(_convert_to_list_of_str), + ) + keypoint_names: list[str] | None = field( + default=None, + converter=converters.optional(_convert_to_list_of_str), + ) + fps: float | None = field( + default=None, + converter=converters.pipe( # type: ignore + converters.optional(float), _convert_fps_to_none_if_invalid + ), + ) + source_software: str | None = field( + default=None, + validator=validators.optional(validators.instance_of(str)), + ) + + # Class variables + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "keypoints", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "confidence") + + # Add validators + @position_array.validator + def _validate_position_array(self, attribute, value): + _validate_type_ndarray(value) + if value.ndim != 4: + raise log_error( + ValueError, + f"Expected '{attribute.name}' to have 4 dimensions, " + f"but got {value.ndim}.", + ) + if value.shape[-1] not in [2, 3]: + raise log_error( + ValueError, + f"Expected '{attribute.name}' to have 2 or 3 spatial " + f"dimensions, but got {value.shape[-1]}.", + ) + + @confidence_array.validator + def _validate_confidence_array(self, attribute, value): + if value is not None: + _validate_type_ndarray(value) + _validate_array_shape( + attribute, value, expected_shape=self.position_array.shape[:-1] + ) + + @individual_names.validator + def _validate_individual_names(self, attribute, value): + if self.source_software == "LightningPose": + # LightningPose only supports a single individual + _validate_list_length(attribute, value, 1) + else: + _validate_list_length( + attribute, value, self.position_array.shape[1] + ) + + @keypoint_names.validator + def _validate_keypoint_names(self, attribute, value): + _validate_list_length(attribute, value, self.position_array.shape[2]) + + def __attrs_post_init__(self): + """Assign default values to optional attributes (if None).""" + if self.confidence_array is None: + self.confidence_array = np.full( + (self.position_array.shape[:-1]), np.nan, dtype="float32" + ) + log_warning( + "Confidence array was not provided." + "Setting to an array of NaNs." + ) + if self.individual_names is None: + self.individual_names = [ + f"individual_{i}" for i in range(self.position_array.shape[1]) + ] + log_warning( + "Individual names were not provided. " + f"Setting to {self.individual_names}." + ) + if self.keypoint_names is None: + self.keypoint_names = [ + f"keypoint_{i}" for i in range(self.position_array.shape[2]) + ] + log_warning( + "Keypoint names were not provided. " + f"Setting to {self.keypoint_names}." 
+ ) + + +@define(kw_only=True) +class ValidBboxesDataset: + """Class for validating bounding boxes' data for a ``movement`` dataset. + + The validator considers 2D bounding boxes only. It ensures that + within the ``movement bboxes`` dataset: + + - The required ``position_array`` and ``shape_array`` are numpy arrays, + with the last dimension containing 2 spatial coordinates. + - The optional ``confidence_array``, if provided, is a numpy array + with its shape matching the first two dimensions of the + ``position_array``; otherwise, it defaults to an array of NaNs. + - The optional ``individual_names``, if provided, match the number of + individuals in the dataset; otherwise, default names are assigned. + - The optional ``frame_array``, if provided, is a column vector + with the frame numbers; otherwise, it defaults to an array of + 0-based integers. + - The optional ``fps`` is a positive float; otherwise, it defaults to None. + - The optional ``source_software`` is a string; otherwise, it defaults to + None. + + Attributes + ---------- + position_array : np.ndarray + Array of shape (n_frames, n_individuals, n_space) + containing the tracks of the bounding boxes' centroids. + shape_array : np.ndarray + Array of shape (n_frames, n_individuals, n_space) + containing the shape of the bounding boxes. The shape of a bounding + box is its width (extent along the x-axis of the image) and height + (extent along the y-axis of the image). + confidence_array : np.ndarray, optional + Array of shape (n_frames, n_individuals) containing + the confidence scores of the bounding boxes. If None (default), the + confidence scores are set to an array of NaNs. + individual_names : list of str, optional + List of individual names for the tracked bounding boxes in the video. + If None (default), bounding boxes are assigned names based on the size + of the ``position_array``. The names will be in the format of + ``id_<N>``, where <N> is an integer from 0 to + ``position_array.shape[1]-1``. + frame_array : np.ndarray, optional + Array of shape (n_frames, 1) containing the frame numbers for which + bounding boxes are defined. If None (default), frame numbers will + be assigned based on the first dimension of the ``position_array``, + starting from 0. + fps : float, optional + Frames per second defining the sampling rate of the data. + Defaults to None. + source_software : str, optional + Name of the software that generated the data. Defaults to None. + + Raises + ------ + ValueError + If the dataset does not meet the ``movement bboxes`` dataset + requirements.
+ + """ + + # Required attributes + position_array: np.ndarray = field() + shape_array: np.ndarray = field() + + # Optional attributes + confidence_array: np.ndarray | None = field(default=None) + individual_names: list[str] | None = field( + default=None, + converter=converters.optional( + _convert_to_list_of_str + ), # force into list of strings if not + ) + frame_array: np.ndarray | None = field(default=None) + fps: float | None = field( + default=None, + converter=converters.pipe( # type: ignore + converters.optional(float), _convert_fps_to_none_if_invalid + ), + ) + source_software: str | None = field( + default=None, + validator=validators.optional(validators.instance_of(str)), + ) + + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "shape", "confidence") + + # Validators + @position_array.validator + @shape_array.validator + def _validate_position_and_shape_arrays(self, attribute, value): + _validate_type_ndarray(value) + # check last dimension (spatial) has 2 coordinates + n_expected_spatial_coordinates = 2 + if value.shape[-1] != n_expected_spatial_coordinates: + raise log_error( + ValueError, + f"Expected '{attribute.name}' to have 2 spatial coordinates, " + f"but got {value.shape[-1]}.", + ) + + @individual_names.validator + def _validate_individual_names(self, attribute, value): + if value is not None: + _validate_list_length( + attribute, value, self.position_array.shape[1] + ) + # check n_individual_names are unique + # NOTE: combined with the requirement above, we are enforcing + # unique IDs per frame + if len(value) != len(set(value)): + raise log_error( + ValueError, + "individual_names passed to the dataset are not unique. " + f"There are {len(value)} elements in the list, but " + f"only {len(set(value))} are unique.", + ) + + @confidence_array.validator + def _validate_confidence_array(self, attribute, value): + if value is not None: + _validate_type_ndarray(value) + _validate_array_shape( + attribute, value, expected_shape=self.position_array.shape[:-1] + ) + + @frame_array.validator + def _validate_frame_array(self, attribute, value): + if value is not None: + _validate_type_ndarray(value) + # should be a column vector (n_frames, 1) + _validate_array_shape( + attribute, + value, + expected_shape=(self.position_array.shape[0], 1), + ) + # check frames are continuous: exactly one frame number per row + if not np.all(np.diff(value, axis=0) == 1): + raise log_error( + ValueError, + f"Frame numbers in {attribute.name} are not continuous.", + ) + + # Define defaults + def __attrs_post_init__(self): + """Assign default values to optional attributes (if None). + + If no confidence_array is provided, set it to an array of NaNs. + If no individual names are provided, assign them unique IDs per frame, + starting with 0 ("id_0"). + """ + # assign default confidence_array + if self.confidence_array is None: + self.confidence_array = np.full( + (self.position_array.shape[:-1]), + np.nan, + dtype="float32", + ) + log_warning( + "Confidence array was not provided. " + "Setting to an array of NaNs." + ) + # assign default individual_names + if self.individual_names is None: + self.individual_names = [ + f"id_{i}" for i in range(self.position_array.shape[1]) + ] + log_warning( + "Individual names for the bounding boxes " + "were not provided. 
" + "Setting to 0-based IDs that are unique per frame: \n" + f"{self.individual_names}.\n" + ) + # assign default frame_array + if self.frame_array is None: + n_frames = self.position_array.shape[0] + self.frame_array = np.arange(n_frames).reshape(-1, 1) + log_warning( + "Frame numbers were not provided. " + "Setting to an array of 0-based integers." + ) diff --git a/movement/validators/files.py b/movement/validators/files.py new file mode 100644 index 00000000..8d013a95 --- /dev/null +++ b/movement/validators/files.py @@ -0,0 +1,435 @@ +"""``attrs`` classes for validating file paths.""" + +import ast +import os +import re +from pathlib import Path +from typing import Literal + +import h5py +import pandas as pd +from attrs import define, field, validators + +from movement.utils.logging import log_error + + +@define +class ValidFile: + """Class for validating file paths. + + The validator ensures that the file: + + - is not a directory, + - exists if it is meant to be read, + - does not exist if it is meant to be written, + - has the expected access permission(s), and + - has one of the expected suffix(es). + + Attributes + ---------- + path : str or pathlib.Path + Path to the file. + expected_permission : {"r", "w", "rw"} + Expected access permission(s) for the file. If "r", the file is + expected to be readable. If "w", the file is expected to be writable. + If "rw", the file is expected to be both readable and writable. + Default: "r". + expected_suffix : list of str + Expected suffix(es) for the file. If an empty list (default), this + check is skipped. + + Raises + ------ + IsADirectoryError + If the path points to a directory. + PermissionError + If the file does not have the expected access permission(s). + FileNotFoundError + If the file does not exist when ``expected_permission`` is "r" or "rw". + FileExistsError + If the file exists when ``expected_permission`` is "w". + ValueError + If the file does not have one of the expected suffix(es). + + """ + + path: Path = field(converter=Path, validator=validators.instance_of(Path)) + expected_permission: Literal["r", "w", "rw"] = field( + default="r", validator=validators.in_(["r", "w", "rw"]), kw_only=True + ) + expected_suffix: list[str] = field(factory=list, kw_only=True) + + @path.validator + def _path_is_not_dir(self, attribute, value): + """Ensure that the path does not point to a directory.""" + if value.is_dir(): + raise log_error( + IsADirectoryError, + f"Expected a file path but got a directory: {value}.", + ) + + @path.validator + def _file_exists_when_expected(self, attribute, value): + """Ensure that the file exists (or not) as needed. + + This depends on the expected usage (read and/or write). + """ + if "r" in self.expected_permission: + if not value.exists(): + raise log_error( + FileNotFoundError, f"File {value} does not exist." + ) + else: # expected_permission is "w" + if value.exists(): + raise log_error( + FileExistsError, f"File {value} already exists." + ) + + @path.validator + def _file_has_access_permissions(self, attribute, value): + """Ensure that the file has the expected access permission(s). + + Raises a PermissionError if not. + """ + file_is_readable = os.access(value, os.R_OK) + parent_is_writeable = os.access(value.parent, os.W_OK) + if ("r" in self.expected_permission) and (not file_is_readable): + raise log_error( + PermissionError, + f"Unable to read file: {value}. 
" + "Make sure that you have read permissions.", + ) + if ("w" in self.expected_permission) and (not parent_is_writeable): + raise log_error( + PermissionError, + f"Unable to write to file: {value}. " + "Make sure that you have write permissions.", + ) + + @path.validator + def _file_has_expected_suffix(self, attribute, value): + """Ensure that the file has one of the expected suffix(es).""" + if self.expected_suffix and value.suffix not in self.expected_suffix: + raise log_error( + ValueError, + f"Expected file with suffix(es) {self.expected_suffix} " + f"but got suffix {value.suffix} instead.", + ) + + +@define +class ValidHDF5: + """Class for validating HDF5 files. + + The validator ensures that the file: + + - is in HDF5 format, and + - contains the expected datasets. + + Attributes + ---------- + path : pathlib.Path + Path to the HDF5 file. + expected_datasets : list of str or None + List of names of the expected datasets in the HDF5 file. If an empty + list (default), this check is skipped. + + Raises + ------ + ValueError + If the file is not in HDF5 format or if it does not contain the + expected datasets. + + """ + + path: Path = field(validator=validators.instance_of(Path)) + expected_datasets: list[str] = field(factory=list, kw_only=True) + + @path.validator + def _file_is_h5(self, attribute, value): + """Ensure that the file is indeed in HDF5 format.""" + try: + with h5py.File(value, "r") as f: + f.close() + except Exception as e: + raise log_error( + ValueError, + f"File {value} does not seem to be in valid" "HDF5 format.", + ) from e + + @path.validator + def _file_contains_expected_datasets(self, attribute, value): + """Ensure that the HDF5 file contains the expected datasets.""" + if self.expected_datasets: + with h5py.File(value, "r") as f: + diff = set(self.expected_datasets).difference(set(f.keys())) + if len(diff) > 0: + raise log_error( + ValueError, + f"Could not find the expected dataset(s) {diff} " + f"in file: {value}. ", + ) + + +@define +class ValidDeepLabCutCSV: + """Class for validating DeepLabCut-style .csv files. + + The validator ensures that the file contains the + expected index column levels. + + Attributes + ---------- + path : pathlib.Path + Path to the .csv file. + + Raises + ------ + ValueError + If the .csv file does not contain the expected DeepLabCut index column + levels among its top rows. + + """ + + path: Path = field(validator=validators.instance_of(Path)) + + @path.validator + def _file_contains_expected_levels(self, attribute, value): + """Ensure that the .csv file contains the expected index column levels. + + These are to be found among the top 4 rows of the file. + """ + expected_levels = ["scorer", "bodyparts", "coords"] + + with open(value) as f: + top4_row_starts = [f.readline().split(",")[0] for _ in range(4)] + + if top4_row_starts[3].isdigit(): + # if 4th row starts with a digit, assume single-animal DLC file + expected_levels.append(top4_row_starts[3]) + else: + # otherwise, assume multi-animal DLC file + expected_levels.insert(1, "individuals") + + if top4_row_starts != expected_levels: + raise log_error( + ValueError, + ".csv header rows do not match the known format for " + "DeepLabCut pose estimation output files.", + ) + + +@define +class ValidVIATracksCSV: + """Class for validating VIA tracks .csv files. + + The validator ensures that the file: + + - contains the expected header, + - contains valid frame numbers, + - contains tracked bounding boxes, and + - defines bounding boxes whose IDs are unique per image file. 
+ + Attributes + ---------- + path : pathlib.Path + Path to the VIA tracks .csv file. + + Raises + ------ + ValueError + If the file does not match the VIA tracks .csv file requirements. + + """ + + path: Path = field(validator=validators.instance_of(Path)) + + @path.validator + def _file_contains_valid_header(self, attribute, value): + """Ensure the VIA tracks .csv file contains the expected header.""" + expected_header = [ + "filename", + "file_size", + "file_attributes", + "region_count", + "region_id", + "region_shape_attributes", + "region_attributes", + ] + + with open(value) as f: + header = f.readline().strip("\n").split(",") + + if header != expected_header: + raise log_error( + ValueError, + ".csv header row does not match the known format for " + "VIA tracks .csv files. " + f"Expected {expected_header} but got {header}.", + ) + + @path.validator + def _file_contains_valid_frame_numbers(self, attribute, value): + """Ensure that the VIA tracks .csv file contains valid frame numbers. + + This involves: + + - Checking that frame numbers are included in ``file_attributes`` or + encoded in the image file ``filename``. + - Checking the frame number can be cast as an integer. + - Checking that there are as many unique frame numbers as unique image + files. + + If the frame number is included as part of the image file name, then + it is expected as an integer led by at least one zero, between "_" and + ".", followed by the file extension. + + """ + df = pd.read_csv(value, sep=",", header=0) + + # Extract list of file attributes (dicts) + file_attributes_dicts = [ + ast.literal_eval(d) for d in df.file_attributes + ] + + # If 'frame' is a file_attribute for all files: + # extract frame number + list_frame_numbers = [] + if all(["frame" in d for d in file_attributes_dicts]): + for k_i, k in enumerate(file_attributes_dicts): + try: + list_frame_numbers.append(int(k["frame"])) + except Exception as e: + raise log_error( + ValueError, + f"{df.filename.iloc[k_i]} (row {k_i}): " + "'frame' file attribute cannot be cast as an integer. " + f"Please review the file attributes: {k}.", + ) from e + + # else: extract frame number from filename. + else: + pattern = r"_(0\d*)\.\w+$" + + for f_i, f in enumerate(df["filename"]): + regex_match = re.search(pattern, f) + if regex_match: # if there is a pattern match + list_frame_numbers.append( + int(regex_match.group(1)) # type: ignore + # the match will always be castable as integer + ) + else: + raise log_error( + ValueError, + f"{f} (row {f_i}): " + "a frame number could not be extracted from the " + "filename. If included in the filename, the frame " + "number is expected as a zero-padded integer between " + "an underscore '_' and the file extension " + "(e.g. img_00234.png).", + ) + + # Check we have as many unique frame numbers as unique image files + if len(set(list_frame_numbers)) != len(df.filename.unique()): + raise log_error( + ValueError, + "The number of unique frame numbers does not match the number " + "of unique image files. Please review the VIA tracks .csv " + "file and ensure a unique frame number is defined for each " + "file. ", + ) + + @path.validator + def _file_contains_tracked_bboxes(self, attribute, value): + """Ensure that the VIA tracks .csv contains tracked bounding boxes. + + This involves: + + - Checking that the bounding boxes are defined as rectangles. + - Checking that the bounding boxes have all geometric parameters + (``["x", "y", "width", "height"]``). + - Checking that the bounding boxes have a track ID defined. 
+ - Checking that the track ID can be cast as an integer. + """ + df = pd.read_csv(value, sep=",", header=0) + + for row in df.itertuples(): + row_region_shape_attrs = ast.literal_eval( + row.region_shape_attributes + ) + row_region_attrs = ast.literal_eval(row.region_attributes) + + # check annotation is a rectangle + if row_region_shape_attrs["name"] != "rect": + raise log_error( + ValueError, + f"{row.filename} (row {row.Index}): " + "bounding box shape must be 'rect' (rectangular) " + "but instead got " + f"'{row_region_shape_attrs['name']}'.", + ) + + # check all geometric parameters for the box are defined + if not all( + [ + key in row_region_shape_attrs + for key in ["x", "y", "width", "height"] + ] + ): + raise log_error( + ValueError, + f"{row.filename} (row {row.Index}): " + f"at least one bounding box shape parameter is missing. " + "Expected 'x', 'y', 'width', 'height' to exist as " + "'region_shape_attributes', but got " + f"'{list(row_region_shape_attrs.keys())}'.", + ) + + # check track ID is defined + if "track" not in row_region_attrs: + raise log_error( + ValueError, + f"{row.filename} (row {row.Index}): " + "bounding box does not have a 'track' attribute defined " + "under 'region_attributes'. " + "Please review the VIA tracks .csv file.", + ) + + # check track ID is castable as an integer + try: + int(row_region_attrs["track"]) + except Exception as e: + raise log_error( + ValueError, + f"{row.filename} (row {row.Index}): " + "the track ID for the bounding box cannot be cast " + "as an integer. Please review the VIA tracks .csv file.", + ) from e + + @path.validator + def _file_contains_unique_track_ids_per_filename(self, attribute, value): + """Ensure the VIA tracks .csv contains unique track IDs per filename. + + It checks that bounding boxes IDs are defined once per image file. + """ + df = pd.read_csv(value, sep=",", header=0) + + list_unique_filenames = list(set(df.filename)) + for file in list_unique_filenames: + df_one_filename = df.loc[df["filename"] == file] + + list_track_ids_one_filename = [ + int(ast.literal_eval(row.region_attributes)["track"]) + for row in df_one_filename.itertuples() + ] + + if len(set(list_track_ids_one_filename)) != len( + list_track_ids_one_filename + ): + raise log_error( + ValueError, + f"{file}: " + "multiple bounding boxes in this file " + "have the same track ID. 
" + "Please review the VIA tracks .csv file.", + ) diff --git a/pyproject.toml b/pyproject.toml index d97bc3c4..27348c29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,13 @@ [project] name = "movement" authors = [ - { name = "Niko Sirmpilatze", email = "niko.sirbiladze@gmail.com" }, + { name = "Nikoloz Sirmpilatze", email = "niko.sirbiladze@gmail.com" }, { name = "Chang Huan Lo", email = "changhuan.lo@ucl.ac.uk" }, + { name = "Sofía Miñano", email = "s.minano@ucl.ac.uk" }, ] description = "Analysis of body movement" readme = "README.md" -requires-python = ">=3.9.0" +requires-python = ">=3.10.0" dynamic = ["version"] license = { text = "BSD-3-Clause" } @@ -27,9 +28,9 @@ classifiers = [ "Development Status :: 3 - Alpha", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Operating System :: OS Independent", "License :: OSI Approved :: BSD License", ] @@ -106,20 +107,23 @@ ignore = [ "D213", # multi-line-summary second line ] select = [ - "E", # pycodestyle errors - "F", # Pyflakes - "UP", # pyupgrade - "I", # isort - "B", # flake8 bugbear - "SIM", # flake8 simplify - "C90", # McCabe complexity - "D", # pydocstyle + "E", # pycodestyle errors + "F", # Pyflakes + "UP", # pyupgrade + "I", # isort + "B", # flake8 bugbear + "SIM", # flake8 simplify + "C90", # McCabe complexity + "D", # pydocstyle + "NPY201", # checks for syntax that was deprecated in numpy2.0 ] per-file-ignores = { "tests/*" = [ "D100", # missing docstring in public module "D205", # missing blank line between summary and description "D103", # missing docstring in public function ], "examples/*" = [ + "B018", # Found useless expression + "D103", # Missing docstring in public function "D400", # first line should end with a period. "D415", # first line should end with a period, question mark... "D205", # missing blank line between summary and description @@ -135,15 +139,17 @@ check-hidden = true [tool.tox] legacy_tox_ini = """ [tox] -requires = tox-conda -envlist = py{39,310,311} +requires = + tox-conda + tox-gh-actions +envlist = py{310,311,312} isolated_build = True [gh-actions] python = - 3.9: py39 3.10: py310 3.11: py311 + 3.12: py312 [testenv] conda_deps = diff --git a/tests/conftest.py b/tests/conftest.py index 9acc418e..e55b0c6a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,19 +11,20 @@ import pytest import xarray as xr -from movement import MovementDataset -from movement.logging import configure_logging from movement.sample_data import fetch_dataset_paths, list_datasets +from movement.utils.logging import configure_logging +from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset def pytest_configure(): """Perform initial configuration for pytest. Fetches pose data file paths as a dictionary for tests. 
""" - pytest.POSE_DATA_PATHS = { - file_name: fetch_dataset_paths(file_name)["poses"] - for file_name in list_datasets() - } + pytest.DATA_PATHS = {} + for file_name in list_datasets(): + paths_dict = fetch_dataset_paths(file_name) + data_path = paths_dict.get("poses") or paths_dict.get("bboxes") + pytest.DATA_PATHS[file_name] = data_path @pytest.fixture(autouse=True) @@ -38,6 +39,7 @@ def setup_logging(tmp_path): ) +# --------- File validator fixtures --------------------------------- @pytest.fixture def unreadable_file(tmp_path): """Return a dictionary containing the file path and @@ -194,9 +196,7 @@ def new_csv_file(tmp_path): @pytest.fixture def dlc_style_df(): """Return a valid DLC-style DataFrame.""" - return pd.read_hdf( - pytest.POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") - ) + return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5")) @pytest.fixture( @@ -211,9 +211,140 @@ def dlc_style_df(): ) def sleap_file(request): """Return the file path for a SLEAP .h5 or .slp file.""" - return pytest.POSE_DATA_PATHS.get(request.param) + return pytest.DATA_PATHS.get(request.param) + + +# ------------ Dataset validator fixtures --------------------------------- + + +@pytest.fixture +def valid_bboxes_arrays_all_zeros(): + """Return a dictionary of valid zero arrays (in terms of shape) for a + ValidBboxesDataset. + """ + # define the shape of the arrays + n_frames, n_individuals, n_space = (10, 2, 2) + + # build a valid array for position or shape with all zeros + valid_bbox_array_all_zeros = np.zeros((n_frames, n_individuals, n_space)) + + # return as a dict + return { + "position": valid_bbox_array_all_zeros, + "shape": valid_bbox_array_all_zeros, + "individual_names": ["id_" + str(id) for id in range(n_individuals)], + } + + +# --------------------- Bboxes dataset fixtures ---------------------------- +@pytest.fixture +def valid_bboxes_arrays(): + """Return a dictionary of valid arrays for a + ValidBboxesDataset representing a uniform linear motion. + + It represents 2 individuals for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line line from the origin. 
+ + All confidence values are set to 0.9 except the following which are set + to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 + """ + # define the shape of the arrays + n_frames, n_individuals, n_space = (10, 2, 2) + + # build a valid array for position + # make bbox with id_i move along x=((-1)**(i))*y line from the origin + # if i is even: along x = y line + # if i is odd: along x = -y line + # moving one unit along each axis in each frame + position = np.empty((n_frames, n_individuals, n_space)) + for i in range(n_individuals): + position[:, i, 0] = np.arange(n_frames) + position[:, i, 1] = (-1) ** i * np.arange(n_frames) + + # build a valid array for constant bbox shape (60, 40) + constant_shape = (60, 40) # width, height in pixels + shape = np.tile(constant_shape, (n_frames, n_individuals, 1)) + + # build an array of confidence values, all 0.9 + confidence = np.full((n_frames, n_individuals), 0.9) + + # set 5 low-confidence values + # - set 3 confidence values for bbox id_0 to 0.1 + # - set 2 confidence values for bbox id_1 to 0.1 + idx_start = 2 + confidence[idx_start : idx_start + 3, 0] = 0.1 + confidence[idx_start : idx_start + 2, 1] = 0.1 + + return { + "position": position, + "shape": shape, + "confidence": confidence, + } + + +@pytest.fixture +def valid_bboxes_dataset( + valid_bboxes_arrays, +): + """Return a valid bboxes dataset for two individuals moving in uniform + linear motion, with 5 frames with low confidence values and time in frames. + """ + dim_names = ValidBboxesDataset.DIM_NAMES + + position_array = valid_bboxes_arrays["position"] + shape_array = valid_bboxes_arrays["shape"] + confidence_array = valid_bboxes_arrays["confidence"] + + n_frames, n_individuals, _ = position_array.shape + + return xr.Dataset( + data_vars={ + "position": xr.DataArray(position_array, dims=dim_names), + "shape": xr.DataArray(shape_array, dims=dim_names), + "confidence": xr.DataArray(confidence_array, dims=dim_names[:-1]), + }, + coords={ + dim_names[0]: np.arange(n_frames), + dim_names[1]: [f"id_{id}" for id in range(n_individuals)], + dim_names[2]: ["x", "y"], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "test", + "source_file": "test_bboxes.csv", + "ds_type": "bboxes", + }, + ) + + +@pytest.fixture +def valid_bboxes_dataset_in_seconds(valid_bboxes_dataset): + """Return a valid bboxes dataset with time in seconds. + + The origin of time is assumed to be time = frame 0 = 0 seconds. 
+ """ + fps = 60 + valid_bboxes_dataset["time"] = valid_bboxes_dataset.time / fps + valid_bboxes_dataset.attrs["time_unit"] = "seconds" + valid_bboxes_dataset.attrs["fps"] = fps + return valid_bboxes_dataset + + +@pytest.fixture +def valid_bboxes_dataset_with_nan(valid_bboxes_dataset): + """Return a valid bboxes dataset with NaN values in the position array.""" + # Set 3 NaN values in the position array for id_0 + valid_bboxes_dataset.position.loc[ + {"individuals": "id_0", "time": [3, 7, 8]} + ] = np.nan + return valid_bboxes_dataset +# --------------------- Poses dataset fixtures ---------------------------- @pytest.fixture def valid_position_array(): """Return a function that generates different kinds @@ -245,24 +376,27 @@ def _valid_position_array(array_type): @pytest.fixture def valid_poses_dataset(valid_position_array, request): """Return a valid pose tracks dataset.""" - dim_names = MovementDataset.dim_names + dim_names = ValidPosesDataset.DIM_NAMES # create a multi_individual_array by default unless overridden via param try: array_format = request.param except AttributeError: array_format = "multi_individual_array" position_array = valid_position_array(array_format) - n_individuals, n_keypoints = position_array.shape[1:3] + n_frames, n_individuals, n_keypoints = position_array.shape[:3] return xr.Dataset( data_vars={ "position": xr.DataArray(position_array, dims=dim_names), "confidence": xr.DataArray( - np.ones(position_array.shape[:-1]), + np.repeat( + np.linspace(0.1, 1.0, n_frames), + n_individuals * n_keypoints, + ).reshape(position_array.shape[:-1]), dims=dim_names[:-1], ), }, coords={ - "time": np.arange(position_array.shape[0]), + "time": np.arange(n_frames), "individuals": [f"ind{i}" for i in range(1, n_individuals + 1)], "keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)], "space": ["x", "y"], @@ -272,19 +406,161 @@ def valid_poses_dataset(valid_position_array, request): "time_unit": "frames", "source_software": "SLEAP", "source_file": "test.h5", + "ds_type": "poses", }, ) +@pytest.fixture +def multi_view_dataset(): + view_names = ["view_0", "view_1"] + new_coord_views = xr.DataArray(view_names, dims="view") + dataset_list = [valid_poses_dataset() for _ in range(len(view_names))] + return xr.concat(dataset_list, dim=new_coord_views) + + @pytest.fixture def valid_poses_dataset_with_nan(valid_poses_dataset): """Return a valid pose tracks dataset with NaN values.""" + # Sets position for all keypoints in individual ind1 to NaN + # at timepoints 3, 7, 8 valid_poses_dataset.position.loc[ {"individuals": "ind1", "time": [3, 7, 8]} ] = np.nan return valid_poses_dataset +@pytest.fixture +def valid_poses_array_uniform_linear_motion(): + """Return a dictionary of valid arrays for a + ValidPosesDataset representing a uniform linear motion. + + It represents 2 individuals with 3 keypoints, for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line line from the origin. + + All confidence values for all keypoints are set to 0.9 except + for the keypoints at the following frames which are set to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 + """ + # define the shape of the arrays + n_frames, n_individuals, n_keypoints, n_space = (10, 2, 3, 2) + + # define centroid (index=0) trajectory in position array + # for each individual, the centroid moves along + # the x=+/-y line, starting from the origin. 
+ # - individual 0 moves along x = y line + # - individual 1 moves along x = -y line + # They move one unit along x and y axes in each frame + frames = np.arange(n_frames) + position = np.empty((n_frames, n_individuals, n_keypoints, n_space)) + position[:, :, 0, 0] = frames[:, None] # reshape to (n_frames, 1) + position[:, 0, 0, 1] = frames + position[:, 1, 0, 1] = -frames + + # define trajectory of left and right keypoints + # for individual 0, at each timepoint: + # - the left keypoint (index=1) is at x_centroid, y_centroid + 1 + # - the right keypoint (index=2) is at x_centroid + 1, y_centroid + # for individual 1, at each timepoint: + # - the left keypoint (index=1) is at x_centroid - 1, y_centroid + # - the right keypoint (index=2) is at x_centroid, y_centroid + 1 + offsets = [ + [(0, 1), (1, 0)], # individual 0: left, right keypoints (x,y) offsets + [(-1, 0), (0, 1)], # individual 1: left, right keypoints (x,y) offsets + ] + for i in range(n_individuals): + for kpt in range(1, n_keypoints): + position[:, i, kpt, 0] = ( + position[:, i, 0, 0] + offsets[i][kpt - 1][0] + ) + position[:, i, kpt, 1] = ( + position[:, i, 0, 1] + offsets[i][kpt - 1][1] + ) + + # build an array of confidence values, all 0.9 + confidence = np.full((n_frames, n_individuals, n_keypoints), 0.9) + # set 5 low-confidence values + # - set 3 confidence values for individual id_0's centroid to 0.1 + # - set 2 confidence values for individual id_1's centroid to 0.1 + idx_start = 2 + confidence[idx_start : idx_start + 3, 0, 0] = 0.1 + confidence[idx_start : idx_start + 2, 1, 0] = 0.1 + + return {"position": position, "confidence": confidence} + + +@pytest.fixture +def valid_poses_dataset_uniform_linear_motion( + valid_poses_array_uniform_linear_motion, +): + """Return a valid poses dataset for two individuals moving in uniform + linear motion, with 5 frames with low confidence values and time in frames. + """ + dim_names = ValidPosesDataset.DIM_NAMES + + position_array = valid_poses_array_uniform_linear_motion["position"] + confidence_array = valid_poses_array_uniform_linear_motion["confidence"] + + n_frames, n_individuals, _, _ = position_array.shape + + return xr.Dataset( + data_vars={ + "position": xr.DataArray(position_array, dims=dim_names), + "confidence": xr.DataArray(confidence_array, dims=dim_names[:-1]), + }, + coords={ + dim_names[0]: np.arange(n_frames), + dim_names[1]: [f"id_{i}" for i in range(1, n_individuals + 1)], + dim_names[2]: ["centroid", "left", "right"], + dim_names[3]: ["x", "y"], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "test", + "source_file": "test_poses.h5", + "ds_type": "poses", + }, + ) + + +@pytest.fixture +def valid_poses_dataset_uniform_linear_motion_with_nans( + valid_poses_dataset_uniform_linear_motion, +): + """Return a valid poses dataset with NaN values in the position array. 
+ + Specifically, we will introducde: + - 1 NaN value in the centroid keypoint of individual id_1 at time=0 + - 5 NaN values in the left keypoint of individual id_1 (frames 3-7) + - 10 NaN values in the right keypoint of individual id_1 (all frames) + """ + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "centroid", + "time": 0, + } + ] = np.nan + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "left", + "time": slice(3, 7), + } + ] = np.nan + valid_poses_dataset_uniform_linear_motion.position.loc[ + { + "individuals": "id_1", + "keypoints": "right", + } + ] = np.nan + return valid_poses_dataset_uniform_linear_motion + + +# -------------------- Invalid datasets fixtures ------------------------------ @pytest.fixture def not_a_dataset(): """Return data that is not a pose tracks dataset.""" @@ -298,73 +574,308 @@ def empty_dataset(): @pytest.fixture -def missing_var_dataset(valid_poses_dataset): - """Return a pose tracks dataset missing an expected variable.""" +def missing_var_poses_dataset(valid_poses_dataset): + """Return a poses dataset missing position variable.""" return valid_poses_dataset.drop_vars("position") @pytest.fixture -def missing_dim_dataset(valid_poses_dataset): - """Return a pose tracks dataset missing an expected dimension.""" +def missing_var_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing position variable.""" + return valid_bboxes_dataset.drop_vars("position") + + +@pytest.fixture +def missing_two_vars_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing position and shape variables.""" + return valid_bboxes_dataset.drop_vars(["position", "shape"]) + + +@pytest.fixture +def missing_dim_poses_dataset(valid_poses_dataset): + """Return a poses dataset missing the time dimension.""" return valid_poses_dataset.rename({"time": "tame"}) -@pytest.fixture( - params=[ - "not_a_dataset", - "empty_dataset", - "missing_var_dataset", - "missing_dim_dataset", - ] -) -def invalid_poses_dataset(request): - """Return an invalid pose tracks dataset.""" - return request.getfixturevalue(request.param) +@pytest.fixture +def missing_dim_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the time dimension.""" + return valid_bboxes_dataset.rename({"time": "tame"}) +@pytest.fixture +def missing_two_dims_bboxes_dataset(valid_bboxes_dataset): + """Return a bboxes dataset missing the time and space dimensions.""" + return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"}) + + +# --------------------------- Kinematics fixtures --------------------------- @pytest.fixture(params=["displacement", "velocity", "acceleration"]) def kinematic_property(request): """Return a kinematic property.""" return request.param +# ---------------- VIA tracks CSV file fixtures ---------------------------- +@pytest.fixture +def via_tracks_csv_with_invalid_header(tmp_path): + """Return the file path for a file with invalid header.""" + file_path = tmp_path / "invalid_via_tracks.csv" + with open(file_path, "w") as f: + f.write("filename,file_size,file_attributes\n") + f.write("1,2,3") + return file_path + + +@pytest.fixture +def via_tracks_csv_with_valid_header(tmp_path): + file_path = tmp_path / "sample_via_tracks.csv" + with open(file_path, "w") as f: + f.write( + "filename," + "file_size," + "file_attributes," + "region_count," + "region_id," + "region_shape_attributes," + "region_attributes" + ) + f.write("\n") + return file_path 
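+
+
+# Editorial sketch (not referenced by any test): a fixture appending a
+# single row that passes every ValidVIATracksCSV check, namely a
+# zero-padded frame number in the filename, a rectangular shape defined
+# by "x", "y", "width" and "height", and a track ID castable to int and
+# unique within the file. Each fixture below breaks one of the
+# validator's checks.
+@pytest.fixture
+def via_tracks_csv_with_one_valid_row(via_tracks_csv_with_valid_header):
+    """Return the file path for a VIA tracks .csv file with one valid row."""
+    file_path = via_tracks_csv_with_valid_header
+    with open(file_path, "a") as f:
+        f.write(
+            "04.09.2023-04-Right_RE_test_frame_01.png,"
+            "26542080,"
+            '"{""clip"":123}",'
+            "1,"
+            "0,"
+            '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",'
+            '"{""track"":""71""}"'
+        )
+    return file_path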
+ + +@pytest.fixture +def frame_number_in_file_attribute_not_integer( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with invalid frame + number defined as file_attribute. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_A.png," + "26542080," + '"{""clip"":123, ""frame"":""FOO""}",' # frame number is a string + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + return file_path + + +@pytest.fixture +def frame_number_in_filename_wrong_pattern( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with invalid frame + number defined in the frame's filename. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_1.png," # frame not zero-padded + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + return file_path + + +@pytest.fixture +def more_frame_numbers_than_filenames( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with more + frame numbers than filenames. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test.png," + "26542080," + '"{""clip"":123, ""frame"":24}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + f.write("\n") + f.write( + "04.09.2023-04-Right_RE_test.png," # same filename as previous row + "26542080," + '"{""clip"":123, ""frame"":25}",' # different frame number + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + return file_path + + +@pytest.fixture +def less_frame_numbers_than_filenames( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with with less + frame numbers than filenames. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_A.png," + "26542080," + '"{""clip"":123, ""frame"":24}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + f.write("\n") + f.write( + "04.09.2023-04-Right_RE_test_B.png," # different filename + "26542080," + '"{""clip"":123, ""frame"":24}",' # same frame as previous row + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + return file_path + + +@pytest.fixture +def region_shape_attribute_not_rect( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with invalid shape in + region_shape_attributes. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""circle"",""cx"":1049,""cy"":1006,""r"":125}",' + '"{""track"":""71""}"' + ) # annotation of circular shape + return file_path + + +@pytest.fixture +def region_shape_attribute_missing_x( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with missing `x` key in + region_shape_attributes. 
+ """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) # region_shape_attributes is missing ""x"" key + return file_path + + +@pytest.fixture +def region_attribute_missing_track( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with missing track + attribute in region_attributes. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""foo"":""71""}"' # missing ""track"" + ) + return file_path + + +@pytest.fixture +def track_id_not_castable_as_int( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with a track ID + attribute not castable as an integer. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""FOO""}"' # ""track"" not castable as int + ) + return file_path + + +@pytest.fixture +def track_ids_not_unique_per_frame( + via_tracks_csv_with_valid_header, +): + """Return the file path for a VIA tracks .csv file with a track ID + that appears twice in the same frame. + """ + file_path = via_tracks_csv_with_valid_header + with open(file_path, "a") as f: + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":526.236,""y"":393.281,""width"":46,""height"":38}",' + '"{""track"":""71""}"' + ) + f.write("\n") + f.write( + "04.09.2023-04-Right_RE_test_frame_01.png," + "26542080," + '"{""clip"":123}",' + "1," + "0," + '"{""name"":""rect"",""x"":2567.627,""y"":466.888,""width"":40,""height"":37}",' + '"{""track"":""71""}"' # same track ID as the previous row + ) + return file_path + + +# ----------------- Helpers fixture ----------------- class Helpers: - """Generic helper methods for ``movement`` testing modules.""" + """Generic helper methods for ``movement`` test modules.""" @staticmethod - def count_nans(ds): - """Count NaNs in the x coordinate timeseries of the first keypoint - of the first individual in the dataset. - """ - n_nans = np.count_nonzero( - np.isnan( - ds.position.isel(individuals=0, keypoints=0, space=0).values - ) - ) - return n_nans + def count_nans(da): + """Count number of NaNs in a DataArray.""" + return da.isnull().sum().item() @staticmethod - def count_nan_repeats(ds): - """Count the number of continuous stretches of NaNs in the - x coordinate timeseries of the first keypoint of the first individual - in the dataset. 
- """ - x = ds.position.isel(individuals=0, keypoints=0, space=0).values - repeats = [] - running_count = 1 - for i in range(len(x)): - if i != len(x) - 1: - if np.isnan(x[i]) and np.isnan(x[i + 1]): - running_count += 1 - elif np.isnan(x[i]): - repeats.append(running_count) - running_count = 1 - else: - running_count = 1 - elif np.isnan(x[i]): - repeats.append(running_count) - running_count = 1 - return len(repeats) + def count_consecutive_nans(da): + """Count occurrences of consecutive NaNs in a DataArray.""" + return (da.isnull().astype(int).diff("time") == 1).sum().item() @pytest.fixture diff --git a/tests/test_integration/test_filtering.py b/tests/test_integration/test_filtering.py index 1f55f9ca..e3e87901 100644 --- a/tests/test_integration/test_filtering.py +++ b/tests/test_integration/test_filtering.py @@ -3,70 +3,73 @@ from movement.filtering import ( filter_by_confidence, interpolate_over_time, - median_filter, savgol_filter, ) from movement.io import load_poses from movement.sample_data import fetch_dataset_paths -@pytest.fixture(scope="module") +@pytest.fixture def sample_dataset(): - """Return a single-animal sample dataset, with time unit in frames. - This allows us to better control the expected number of NaNs in the tests. - """ + """Return a single-animal sample dataset, with time unit in frames.""" ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[ "poses" ] - return load_poses.from_dlc_file(ds_path, fps=None) + ds = load_poses.from_dlc_file(ds_path) + return ds -@pytest.mark.parametrize("window_length", [3, 5, 6, 13]) -def test_nan_propagation_through_filters( - sample_dataset, window_length, helpers -): - """Tests how NaNs are propagated when passing a dataset through multiple - filters sequentially. For the ``median_filter`` and ``savgol_filter``, - we expect the number of NaNs to increase at most by the filter's window - length minus one (``window_length - 1``) multiplied by the number of - continuous stretches of NaNs present in the input dataset. +@pytest.mark.parametrize("window", [3, 5, 6, 13]) +def test_nan_propagation_through_filters(sample_dataset, window, helpers): + """Test NaN propagation is as expected when passing a DataArray through + filter by confidence, Savgol filter and interpolation. + For the ``savgol_filter``, the number of NaNs is expected to increase + at most by the filter's window length minus one (``window - 1``) + multiplied by the number of consecutive NaNs in the input data. 
""" - # Introduce nans via filter_by_confidence - ds_with_nans = filter_by_confidence(sample_dataset, threshold=0.6) - nans_after_confilt = helpers.count_nans(ds_with_nans) - nan_repeats_after_confilt = helpers.count_nan_repeats(ds_with_nans) - assert nans_after_confilt == 2555, ( - f"Unexpected number of NaNs in filtered dataset: " - f"expected: 2555, got: {nans_after_confilt}" + # Compute number of low confidence keypoints + n_low_confidence_kpts = (sample_dataset.confidence.data < 0.6).sum() + + # Check filter position by confidence creates correct number of NaNs + sample_dataset.update( + { + "position": filter_by_confidence( + sample_dataset.position, + sample_dataset.confidence, + ) + } ) + n_total_nans_input = helpers.count_nans(sample_dataset.position) - # Apply median filter and check that - # it doesn't introduce too many or too few NaNs - ds_medfilt = median_filter(ds_with_nans, window_length) - nans_after_medfilt = helpers.count_nans(ds_medfilt) - nan_repeats_after_medfilt = helpers.count_nan_repeats(ds_medfilt) - max_nans_increase = (window_length - 1) * nan_repeats_after_confilt - assert ( - nans_after_medfilt <= nans_after_confilt + max_nans_increase - ), "Median filter introduced more NaNs than expected." assert ( - nans_after_medfilt >= nans_after_confilt - ), "Median filter mysteriously removed NaNs." + n_total_nans_input + == n_low_confidence_kpts * sample_dataset.sizes["space"] + ) - # Apply savgol filter and check that - # it doesn't introduce too many or too few NaNs - ds_savgol = savgol_filter( - ds_medfilt, window_length, polyorder=2, print_report=True + # Compute maximum expected increase in NaNs due to filtering + n_consecutive_nans_input = helpers.count_consecutive_nans( + sample_dataset.position ) - nans_after_savgol = helpers.count_nans(ds_savgol) - max_nans_increase = (window_length - 1) * nan_repeats_after_medfilt - assert ( - nans_after_savgol <= nans_after_medfilt + max_nans_increase - ), "Savgol filter introduced more NaNs than expected." - assert ( - nans_after_savgol >= nans_after_medfilt - ), "Savgol filter mysteriously removed NaNs." 
+ max_nans_increase = (window - 1) * n_consecutive_nans_input - # Apply interpolate_over_time (without max_gap) to eliminate all NaNs - ds_interpolated = interpolate_over_time(ds_savgol, print_report=True) - assert helpers.count_nans(ds_interpolated) == 0 + # Apply savgol filter and check that number of NaNs is within threshold + sample_dataset.update( + { + "position": savgol_filter( + sample_dataset.position, window, polyorder=2 + ) + } + ) + + n_total_nans_savgol = helpers.count_nans(sample_dataset.position) + + # Check that filtering does not reduce number of nans + assert n_total_nans_savgol >= n_total_nans_input + # Check that the increase in nans is below the expected threshold + assert n_total_nans_savgol - n_total_nans_input <= max_nans_increase + + # Interpolate data (without max_gap) and check it eliminates all NaNs + sample_dataset.update( + {"position": interpolate_over_time(sample_dataset.position)} + ) + assert helpers.count_nans(sample_dataset.position) == 0 diff --git a/tests/test_integration/test_io.py b/tests/test_integration/test_io.py index fc291eff..50f03933 100644 --- a/tests/test_integration/test_io.py +++ b/tests/test_integration/test_io.py @@ -2,7 +2,7 @@ import numpy as np import pytest import xarray as xr -from pytest import POSE_DATA_PATHS +from pytest import DATA_PATHS from movement.io import load_poses, save_poses @@ -15,12 +15,12 @@ def dlc_output_file(self, request, tmp_path): """Return the output file path for a DLC .h5 or .csv file.""" return tmp_path / request.param - def test_load_and_save_to_dlc_df(self, dlc_style_df): + def test_load_and_save_to_dlc_style_df(self, dlc_style_df): """Test that loading pose tracks from a DLC-style DataFrame and converting back to a DataFrame returns the same data values. """ - ds = load_poses.from_dlc_df(dlc_style_df) - df = save_poses.to_dlc_df(ds, split_individuals=False) + ds = load_poses.from_dlc_style_df(dlc_style_df) + df = save_poses.to_dlc_style_df(ds, split_individuals=False) np.testing.assert_allclose(df.values, dlc_style_df.values) def test_save_and_load_dlc_file( @@ -62,7 +62,7 @@ def test_to_sleap_analysis_file_returns_same_h5_file_content( file) to a SLEAP-style .h5 analysis file returns the same file contents. """ - sleap_h5_file_path = POSE_DATA_PATHS.get(sleap_h5_file) + sleap_h5_file_path = DATA_PATHS.get(sleap_h5_file) ds = load_poses.from_sleap_file(sleap_h5_file_path, fps=fps) save_poses.to_sleap_analysis_file(ds, new_h5_file) @@ -93,7 +93,7 @@ def test_to_sleap_analysis_file_source_file(self, file, new_h5_file): to a SLEAP-style .h5 analysis file stores the .slp labels path only when the source file is a .slp file. """ - file_path = POSE_DATA_PATHS.get(file) + file_path = DATA_PATHS.get(file) if file.startswith("DLC"): ds = load_poses.from_dlc_file(file_path) else: diff --git a/tests/test_integration/test_kinematics_vector_transform.py b/tests/test_integration/test_kinematics_vector_transform.py index c81183ea..e91f4d64 100644 --- a/tests/test_integration/test_kinematics_vector_transform.py +++ b/tests/test_integration/test_kinematics_vector_transform.py @@ -1,33 +1,96 @@ -from contextlib import nullcontext as does_not_raise +import math +import numpy as np import pytest import xarray as xr +import movement.kinematics as kin from movement.utils import vector -class TestKinematicsVectorTransform: - """Test the vector transformation functionality with - various kinematic properties. 
+@pytest.mark.parametrize( + "valid_dataset_uniform_linear_motion", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_bboxes_dataset", + ], +) +@pytest.mark.parametrize( + "kinematic_variable, expected_kinematics_polar", + [ + ( + "displacement", + [ + np.vstack( + [ + np.zeros((1, 2)), + np.tile([math.sqrt(2), math.atan(1)], (9, 1)), + ], + ), # Individual 0, rho=sqrt(2), phi=45deg + np.vstack( + [ + np.zeros((1, 2)), + np.tile([math.sqrt(2), -math.atan(1)], (9, 1)), + ] + ), # Individual 1, rho=sqrt(2), phi=-45deg + ], + ), + ( + "velocity", + [ + np.tile( + [math.sqrt(2), math.atan(1)], (10, 1) + ), # Individual O, rho, phi=45deg + np.tile( + [math.sqrt(2), -math.atan(1)], (10, 1) + ), # Individual 1, rho, phi=-45deg + ], + ), + ( + "acceleration", + [ + np.zeros((10, 2)), # Individual 0 + np.zeros((10, 2)), # Individual 1 + ], + ), + ], +) +def test_cart2pol_transform_on_kinematics( + valid_dataset_uniform_linear_motion, + kinematic_variable, + expected_kinematics_polar, + request, +): + """Test transformation between Cartesian and polar coordinates + with various kinematic properties. """ + ds = request.getfixturevalue(valid_dataset_uniform_linear_motion) + kinematic_array_cart = getattr(kin, f"compute_{kinematic_variable}")( + ds.position + ) + kinematic_array_pol = vector.cart2pol(kinematic_array_cart) + + # Build expected data array + expected_array_pol = xr.DataArray( + np.stack(expected_kinematics_polar, axis=1), + # Stack along the "individuals" axis + dims=["time", "individuals", "space"], + ) + if "keypoints" in ds.position.coords: + expected_array_pol = expected_array_pol.expand_dims( + {"keypoints": ds.position.coords["keypoints"].size} + ) + expected_array_pol = expected_array_pol.transpose( + "time", "individuals", "keypoints", "space" + ) + + # Compare the values of the kinematic_array against the expected_array + np.testing.assert_allclose( + kinematic_array_pol.values, expected_array_pol.values + ) - @pytest.mark.parametrize( - "ds, expected_exception", - [ - ("valid_poses_dataset", does_not_raise()), - ("valid_poses_dataset_with_nan", does_not_raise()), - ("missing_dim_dataset", pytest.raises(ValueError)), - ], + # Check we can recover the original Cartesian array + kinematic_array_cart_recover = vector.pol2cart(kinematic_array_pol) + xr.testing.assert_allclose( + kinematic_array_cart, kinematic_array_cart_recover ) - def test_cart_and_pol_transform( - self, ds, expected_exception, kinematic_property, request - ): - """Test transformation between Cartesian and polar coordinates - with various kinematic properties. 
- """ - ds = request.getfixturevalue(ds) - with expected_exception: - data = getattr(ds.move, f"compute_{kinematic_property}")() - pol_data = vector.cart2pol(data) - cart_data = vector.pol2cart(pol_data) - xr.testing.assert_allclose(cart_data, data) diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index a3500b27..d51af1be 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -1,127 +1,289 @@ +from contextlib import nullcontext as does_not_raise + import pytest import xarray as xr from movement.filtering import ( filter_by_confidence, interpolate_over_time, - log_to_attrs, median_filter, savgol_filter, ) -from movement.sample_data import fetch_dataset - -@pytest.fixture(scope="module") -def sample_dataset(): - """Return a single-animal sample dataset, with time unit in seconds.""" - return fetch_dataset("DLC_single-mouse_EPM.predictions.h5") +# Dataset fixtures +list_valid_datasets_without_nans = [ + "valid_poses_dataset", + "valid_bboxes_dataset", +] +list_valid_datasets_with_nans = [ + f"{dataset}_with_nan" for dataset in list_valid_datasets_without_nans +] +list_all_valid_datasets = ( + list_valid_datasets_without_nans + list_valid_datasets_with_nans +) -def test_log_to_attrs(sample_dataset): - """Test for the ``log_to_attrs()`` decorator. Decorates a mock function and - checks that ``attrs`` contains all expected values. +@pytest.mark.parametrize( + "valid_dataset_with_nan", + list_valid_datasets_with_nans, +) +@pytest.mark.parametrize( + "max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)] +) +def test_interpolate_over_time_on_position( + valid_dataset_with_nan, + max_gap, + expected_n_nans_in_position, + helpers, + request, +): + """Test that the number of NaNs decreases after linearly interpolating + over time and that the resulting number of NaNs is as expected + for different values of ``max_gap``. 
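+
+    For context: per space (and keypoint) coordinate, the NaN fixtures have
+    one missing frame at time 3 and two consecutive missing frames at times
+    7 and 8, so ``max_gap=0`` fills neither gap (3 NaNs remain),
+    ``max_gap=1`` fills only the single-frame gap (2 remain), and
+    ``max_gap=2`` or ``None`` fills both (0 remain).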
""" + valid_dataset_in_frames = request.getfixturevalue(valid_dataset_with_nan) - @log_to_attrs - def fake_func(ds, arg, kwarg=None): - return ds + # Get position array with time unit in frames & seconds + # assuming 10 fps = 0.1 s per frame + valid_dataset_in_seconds = valid_dataset_in_frames.copy() + valid_dataset_in_seconds.coords["time"] = ( + valid_dataset_in_seconds.coords["time"] * 0.1 + ) + position = { + "frames": valid_dataset_in_frames.position, + "seconds": valid_dataset_in_seconds.position, + } - ds = fake_func(sample_dataset, "test1", kwarg="test2") + # Count number of NaNs before and after interpolating position + n_nans_before = helpers.count_nans(position["frames"]) + n_nans_after_per_time_unit = {} + for time_unit in ["frames", "seconds"]: + # interpolate + position_interp = interpolate_over_time( + position[time_unit], method="linear", max_gap=max_gap + ) + # count nans + n_nans_after_per_time_unit[time_unit] = helpers.count_nans( + position_interp + ) - assert "log" in ds.attrs - assert ds.attrs["log"][0]["operation"] == "fake_func" + # The number of NaNs should be the same for both datasets + # as max_gap is based on number of missing observations (NaNs) assert ( - ds.attrs["log"][0]["arg_1"] == "test1" - and ds.attrs["log"][0]["kwarg"] == "test2" + n_nans_after_per_time_unit["frames"] + == n_nans_after_per_time_unit["seconds"] ) + # The number of NaNs should decrease after interpolation + n_nans_after = n_nans_after_per_time_unit["frames"] + if max_gap == 0: + assert n_nans_after == n_nans_before + else: + assert n_nans_after < n_nans_before -def test_interpolate_over_time(sample_dataset, helpers): - """Test the ``interpolate_over_time`` function. + # The number of NaNs after interpolating should be as expected + assert n_nans_after == ( + valid_dataset_in_frames.sizes["space"] + * valid_dataset_in_frames.sizes.get("keypoints", 1) + # in bboxes dataset there is no keypoints dimension + * expected_n_nans_in_position + ) - Check that the number of nans is decreased after running this function - on a filtered dataset - """ - ds_filtered = filter_by_confidence(sample_dataset) - ds_interpolated = interpolate_over_time(ds_filtered) - assert helpers.count_nans(ds_interpolated) < helpers.count_nans( - ds_filtered +@pytest.mark.parametrize( + "valid_dataset_no_nans, n_low_confidence_kpts", + [ + ("valid_poses_dataset", 20), + ("valid_bboxes_dataset", 5), + ], +) +def test_filter_by_confidence_on_position( + valid_dataset_no_nans, n_low_confidence_kpts, helpers, request +): + """Test that points below the default 0.6 confidence threshold + are converted to NaN. + """ + # Filter position by confidence + valid_input_dataset = request.getfixturevalue(valid_dataset_no_nans) + position_filtered = filter_by_confidence( + valid_input_dataset.position, + confidence=valid_input_dataset.confidence, + threshold=0.6, ) + # Count number of NaNs in the full array + n_nans = helpers.count_nans(position_filtered) + + # expected number of nans for poses: + # 5 timepoints * 2 individuals * 2 keypoints + # Note: we count the number of nans in the array, so we multiply + # the number of low confidence keypoints by the number of + # space dimensions + assert isinstance(position_filtered, xr.DataArray) + assert n_nans == valid_input_dataset.sizes["space"] * n_low_confidence_kpts + -def test_filter_by_confidence(sample_dataset, caplog, helpers): - """Tests for the ``filter_by_confidence()`` function. 
- Checks that the function filters the expected amount of values - from a known dataset, and tests that this value is logged - correctly. +@pytest.mark.parametrize( + "valid_dataset", + list_all_valid_datasets, +) +@pytest.mark.parametrize( + ("filter_func, filter_kwargs"), + [ + (median_filter, {"window": 2}), + (median_filter, {"window": 4}), + (savgol_filter, {"window": 2, "polyorder": 1}), + (savgol_filter, {"window": 4, "polyorder": 2}), + ], +) +def test_filter_on_position( + filter_func, filter_kwargs, valid_dataset, request +): + """Test that applying a filter to the position data returns + a different xr.DataArray than the input position data. """ - ds_filtered = filter_by_confidence(sample_dataset, threshold=0.6) + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset) + position_filtered = filter_func( + valid_input_dataset.position, **filter_kwargs + ) - assert isinstance(ds_filtered, xr.Dataset) + del position_filtered.attrs["log"] - n_nans = helpers.count_nans(ds_filtered) - assert n_nans == 2555 + # filtered array is an xr.DataArray + assert isinstance(position_filtered, xr.DataArray) - # Check that diagnostics are being logged correctly - assert f"snout: {n_nans}/{ds_filtered.time.values.shape[0]}" in caplog.text + # filtered data should not be equal to the original data + assert not position_filtered.equals(valid_input_dataset.position) -@pytest.mark.parametrize("window_size", [0.2, 1, 4, 12]) -def test_median_filter(sample_dataset, window_size): - """Tests for the ``median_filter()`` function. Checks that - the function successfully receives the input data and - returns a different xr.Dataset with the correct dimensions. +# Expected number of nans in the position array per +# individual, after applying a filter with window size 3 +@pytest.mark.parametrize( + ("valid_dataset, expected_nans_in_filtered_position_per_indiv"), + [ + ( + "valid_poses_dataset", + {0: 0, 1: 0}, + ), # filtering should not introduce nans if input has no nans + ("valid_bboxes_dataset", {0: 0, 1: 0}), + ("valid_poses_dataset_with_nan", {0: 7, 1: 0}), + ("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}), + ], +) +@pytest.mark.parametrize( + ("filter_func, filter_kwargs"), + [ + (median_filter, {"window": 3}), + (savgol_filter, {"window": 3, "polyorder": 2}), + ], +) +def test_filter_with_nans_on_position( + filter_func, + filter_kwargs, + valid_dataset, + expected_nans_in_filtered_position_per_indiv, + helpers, + request, +): + """Test NaN behaviour of the selected filter. The median and SG filters + should set all values to NaN if one element of the sliding window is NaN. 
""" - ds_smoothed = median_filter(sample_dataset, window_size) - # Test whether filter received and returned correct data - assert isinstance(ds_smoothed, xr.Dataset) and ~( - ds_smoothed == sample_dataset - ) - assert ds_smoothed.position.shape == sample_dataset.position.shape + def _assert_n_nans_in_position_per_individual( + valid_input_dataset, + position_filtered, + expected_nans_in_filt_position_per_indiv, + ): + # compute n nans in position after filtering per individual + n_nans_after_filtering_per_indiv = { + i: helpers.count_nans(position_filtered.isel(individuals=i)) + for i in range(valid_input_dataset.sizes["individuals"]) + } + # check number of nans per indiv is as expected + for i in range(valid_input_dataset.sizes["individuals"]): + assert n_nans_after_filtering_per_indiv[i] == ( + expected_nans_in_filt_position_per_indiv[i] + * valid_input_dataset.sizes["space"] + * valid_input_dataset.sizes.get("keypoints", 1) + ) -def test_median_filter_with_nans(valid_poses_dataset_with_nan, helpers): - """Test nan behavior of the ``median_filter()`` function. The - ``valid_poses_dataset_with_nan`` dataset (fixture defined in conftest.py) - contains NaN values in all keypoints of the first individual at times - 3, 7, and 8 (0-indexed, 10 total timepoints). - The median filter should propagate NaNs within the windows of the filter, - but it should not introduce any NaNs for the second individual. - """ - ds_smoothed = median_filter(valid_poses_dataset_with_nan, 3) - # There should be NaNs at 7 timepoints for the first individual - # all except for timepoints 0, 1 and 5 - assert helpers.count_nans(ds_smoothed) == 7 - assert ( - ~ds_smoothed.position.isel(individuals=0, time=[0, 1, 5]) - .isnull() - .any() + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset) + position_filtered = filter_func( + valid_input_dataset.position, **filter_kwargs + ) + + # check number of nans per indiv is as expected + _assert_n_nans_in_position_per_individual( + valid_input_dataset, + position_filtered, + expected_nans_in_filtered_position_per_indiv, ) - # The second individual should not contain any NaNs - assert ~ds_smoothed.position.sel(individuals="ind2").isnull().any() + + # if input had nans, + # individual 1's position at exact timepoints 0, 1 and 5 is not nan + n_nans_input = helpers.count_nans(valid_input_dataset.position) + if n_nans_input != 0: + assert not ( + position_filtered.isel(individuals=0, time=[0, 1, 5]) + .isnull() + .any() + ) -@pytest.mark.parametrize("window_length", [0.2, 1, 4, 12]) -@pytest.mark.parametrize("polyorder", [1, 2, 3]) -def test_savgol_filter(sample_dataset, window_length, polyorder): - """Tests for the ``savgol_filter()`` function. - Checks that the function successfully receives the input - data and returns a different xr.Dataset with the correct - dimensions. +@pytest.mark.parametrize( + "valid_dataset_with_nan", + list_valid_datasets_with_nans, +) +@pytest.mark.parametrize( + "window", + [3, 5, 6, 10], # data is nframes = 10 +) +@pytest.mark.parametrize( + "filter_func", + [median_filter, savgol_filter], +) +def test_filter_with_nans_on_position_varying_window( + valid_dataset_with_nan, window, filter_func, helpers, request +): + """Test that the number of NaNs in the filtered position data + increases at most by the filter's window length minus one + multiplied by the number of consecutive NaNs in the input data. 
""" - ds_smoothed = savgol_filter( - sample_dataset, window_length, polyorder=polyorder + # Prepare kwargs per filter + kwargs = {"window": window} + if filter_func == savgol_filter: + kwargs["polyorder"] = 2 + + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan) + position_filtered = filter_func( + valid_input_dataset.position, + **kwargs, ) - # Test whether filter received and returned correct data - assert isinstance(ds_smoothed, xr.Dataset) and ~( - ds_smoothed == sample_dataset + # Count number of NaNs in the input and filtered position data + n_total_nans_initial = helpers.count_nans(valid_input_dataset.position) + n_consecutive_nans_initial = helpers.count_consecutive_nans( + valid_input_dataset.position ) - assert ds_smoothed.position.shape == sample_dataset.position.shape + + n_total_nans_filtered = helpers.count_nans(position_filtered) + + max_nans_increase = (window - 1) * n_consecutive_nans_initial + + # Check that filtering does not reduce number of nans + assert n_total_nans_filtered >= n_total_nans_initial + # Check that the increase in nans is below the expected threshold + assert n_total_nans_filtered - n_total_nans_initial <= max_nans_increase +@pytest.mark.parametrize( + "valid_dataset", + list_all_valid_datasets, +) @pytest.mark.parametrize( "override_kwargs", [ @@ -130,36 +292,20 @@ def test_savgol_filter(sample_dataset, window_length, polyorder): {"mode": "nearest", "axis": 1}, ], ) -def test_savgol_filter_kwargs_override(sample_dataset, override_kwargs): - """Further tests for the ``savgol_filter()`` function. - Checks that the function raises a ValueError when the ``axis`` keyword - argument is overridden, as this is not allowed. Overriding other keyword - arguments (e.g. ``mode``) should not raise an error. +def test_savgol_filter_kwargs_override( + valid_dataset, override_kwargs, request +): + """Test that overriding keyword arguments in the Savitzky-Golay filter + works, except for the ``axis`` argument, which should raise a ValueError. """ - if "axis" in override_kwargs: - with pytest.raises(ValueError): - savgol_filter(sample_dataset, 5, **override_kwargs) - else: - ds_smoothed = savgol_filter(sample_dataset, 5, **override_kwargs) - assert isinstance(ds_smoothed, xr.Dataset) - - -def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers): - """Test nan behavior of the ``savgol_filter()`` function. The - ``valid_poses_dataset_with_nan`` dataset (fixture defined in conftest.py) - contains NaN values in all keypoints of the first individual at times - 3, 7, and 8 (0-indexed, 10 total timepoints). - The Savitzky-Golay filter should propagate NaNs within the windows of - the filter, but it should not introduce any NaNs for the second individual. 
- """ - ds_smoothed = savgol_filter(valid_poses_dataset_with_nan, 3, polyorder=2) - # There should be NaNs at 7 timepoints for the first individual - # all except for timepoints 0, 1 and 5 - assert helpers.count_nans(ds_smoothed) == 7 - assert ( - ~ds_smoothed.position.isel(individuals=0, time=[0, 1, 5]) - .isnull() - .any() + expected_exception = ( + pytest.raises(ValueError) + if "axis" in override_kwargs + else does_not_raise() ) - # The second individual should not contain any NaNs - assert ~ds_smoothed.position.sel(individuals="ind2").isnull().any() + with expected_exception: + savgol_filter( + request.getfixturevalue(valid_dataset).position, + window=3, + **override_kwargs, + ) diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 2d9096bc..c439bf54 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -1,112 +1,751 @@ +import re from contextlib import nullcontext as does_not_raise import numpy as np import pytest import xarray as xr -from movement.analysis import kinematics - - -class TestKinematics: - """Test suite for the kinematics module.""" - - @pytest.fixture - def expected_dataarray(self, valid_poses_dataset): - """Return a function to generate the expected dataarray - for different kinematic properties. - """ - - def _expected_dataarray(property): - """Return an xarray.DataArray with default values and - the expected dimensions and coordinates. - """ - # Expected x,y values for velocity - x_vals = np.array( - [1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 17.0] - ) - y_vals = np.full((10, 2, 2, 1), 4.0) - if property == "acceleration": - x_vals = np.array( - [1.0, 1.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.0] - ) - y_vals = np.full((10, 2, 2, 1), 0) - elif property == "displacement": - x_vals = np.array( - [0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0] - ) - y_vals[0] = 0 - - x_vals = x_vals.reshape(-1, 1, 1, 1) - # Repeat the x_vals to match the shape of the position - x_vals = np.tile(x_vals, (1, 2, 2, 1)) - return xr.DataArray( - np.concatenate( - [x_vals, y_vals], - axis=-1, - ), - dims=valid_poses_dataset.dims, - coords=valid_poses_dataset.coords, - ) - - return _expected_dataarray - - kinematic_test_params = [ - ("valid_poses_dataset", does_not_raise()), - ("valid_poses_dataset_with_nan", does_not_raise()), - ("missing_dim_dataset", pytest.raises(ValueError)), +from movement import kinematics + + +@pytest.mark.parametrize( + "valid_dataset_uniform_linear_motion", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_bboxes_dataset", + ], +) +@pytest.mark.parametrize( + "kinematic_variable, expected_kinematics", + [ + ( + "displacement", + [ + np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), # Individual 0 + np.multiply( + np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), + np.array([1, -1]), + ), # Individual 1 + ], + ), + ( + "velocity", + [ + np.ones((10, 2)), # Individual 0 + np.multiply( + np.ones((10, 2)), np.array([1, -1]) + ), # Individual 1 + ], + ), + ( + "acceleration", + [ + np.zeros((10, 2)), # Individual 0 + np.zeros((10, 2)), # Individual 1 + ], + ), + ( + "speed", # magnitude of velocity + [ + np.ones(10) * np.sqrt(2), # Individual 0 + np.ones(10) * np.sqrt(2), # Individual 1 + ], + ), + ], +) +def test_kinematics_uniform_linear_motion( + valid_dataset_uniform_linear_motion, + kinematic_variable, + expected_kinematics, # 2D: n_frames, n_space_dims + request, +): + """Test computed kinematics for a uniform linear motion case. 
+ + Uniform linear motion means the individuals move along a line + at constant velocity. + + We consider 2 individuals ("id_0" and "id_1"), + tracked for 10 frames, along x and y: + - id_0 moves along x=y line from the origin + - id_1 moves along x=-y line from the origin + - they both move one unit (pixel) along each axis in each frame + + If the dataset is a poses dataset, we consider 3 keypoints per individual + (centroid, left, right), that are always in front of the centroid keypoint + at 45deg from the trajectory. + """ + # Compute kinematic array from input dataset + position = request.getfixturevalue( + valid_dataset_uniform_linear_motion + ).position + kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( + position + ) + + # Figure out which dimensions to expect in kinematic_array + # and in the final xarray.DataArray + expected_dims = ["time", "individuals"] + if kinematic_variable in ["displacement", "velocity", "acceleration"]: + expected_dims.append("space") + + # Build expected data array from the expected numpy array + expected_array = xr.DataArray( + # Stack along the "individuals" axis + np.stack(expected_kinematics, axis=1), + dims=expected_dims, + ) + if "keypoints" in position.coords: + expected_array = expected_array.expand_dims( + {"keypoints": position.coords["keypoints"].size} + ) + expected_dims.insert(2, "keypoints") + expected_array = expected_array.transpose(*expected_dims) + + # Compare the values of the kinematic_array against the expected_array + np.testing.assert_allclose(kinematic_array.values, expected_array.values) + + +@pytest.mark.parametrize( + "valid_dataset_with_nan", + [ + "valid_poses_dataset_with_nan", + "valid_bboxes_dataset_with_nan", + ], +) +@pytest.mark.parametrize( + "kinematic_variable, expected_nans_per_individual", + [ + ("displacement", [5, 0]), # individual 0, individual 1 + ("velocity", [6, 0]), + ("acceleration", [7, 0]), + ("speed", [6, 0]), + ], +) +def test_kinematics_with_dataset_with_nans( + valid_dataset_with_nan, + kinematic_variable, + expected_nans_per_individual, + helpers, + request, +): + """Test kinematics computation for a dataset with nans. + + We test that the kinematics can be computed and that the number + of nan values in the kinematic array is as expected. 
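A stand-alone sketch of the same uniform-linear-motion case for a single individual; it assumes, as the bboxes parametrisation above implies, that the kinematics functions accept any position DataArray with `time` and `space` dimensions:

    import numpy as np
    import xarray as xr

    from movement import kinematics

    # One individual moving along the x = y line, one unit per axis per frame
    time = np.arange(10)
    position = xr.DataArray(
        np.stack([time, time], axis=-1)[:, np.newaxis, :].astype(float),
        dims=["time", "individuals", "space"],
        coords={"time": time, "individuals": ["id_0"], "space": ["x", "y"]},
    )

    velocity = kinematics.compute_velocity(position)          # (1, 1) at every frame
    speed = kinematics.compute_speed(position)                # sqrt(2) at every frame
    acceleration = kinematics.compute_acceleration(position)  # all zeros
    print(velocity.values[0], speed.values[0], acceleration.values.max())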
+ + """ + # compute kinematic array + valid_dataset = request.getfixturevalue(valid_dataset_with_nan) + position = valid_dataset.position + kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( + position + ) + + # compute n nans in kinematic array per individual + n_nans_kinematics_per_indiv = [ + helpers.count_nans(kinematic_array.isel(individuals=i)) + for i in range(valid_dataset.sizes["individuals"]) ] - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_displacement( - self, ds, expected_exception, expected_dataarray, request - ): - """Test displacement computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_displacement(ds.position) - expected = expected_dataarray("displacement") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [3, 4, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_velocity( - self, ds, expected_exception, expected_dataarray, request - ): - """Test velocity computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_velocity(ds.position) - expected = expected_dataarray("velocity") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_acceleration( - self, ds, expected_exception, expected_dataarray, request - ): - """Test acceleration computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_acceleration(ds.position) - expected = expected_dataarray("acceleration") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("order", [0, -1, 1.0, "1"]) - def test_approximate_derivative_with_invalid_order(self, order): - """Test that an error is raised when the order is non-positive.""" - data = np.arange(10) - expected_exception = ( - ValueError if isinstance(order, int) else TypeError + # expected nans per individual adjusted for space and keypoints dimensions + if "space" in kinematic_array.dims: + n_space_dims = position.sizes["space"] + else: + n_space_dims = 1 + + expected_nans_adjusted = [ + n * n_space_dims * valid_dataset.sizes.get("keypoints", 1) + for n in expected_nans_per_individual + ] + # check number of nans per individual is as expected in kinematic array + np.testing.assert_array_equal( + n_nans_kinematics_per_indiv, expected_nans_adjusted + ) + + +@pytest.mark.parametrize( + "invalid_dataset, expected_exception", + [ + ("not_a_dataset", pytest.raises(AttributeError)), + ("empty_dataset", pytest.raises(AttributeError)), + ("missing_var_poses_dataset", pytest.raises(AttributeError)), + ("missing_var_bboxes_dataset", pytest.raises(AttributeError)), + ("missing_dim_poses_dataset", pytest.raises(ValueError)), + ("missing_dim_bboxes_dataset", pytest.raises(ValueError)), + ], +) +@pytest.mark.parametrize( + "kinematic_variable", + [ + "displacement", + "velocity", + "acceleration", + "speed", + ], +) +def test_kinematics_with_invalid_dataset( + invalid_dataset, + expected_exception, + kinematic_variable, + request, +): + """Test kinematics computation with an invalid dataset.""" + with 
expected_exception: + position = request.getfixturevalue(invalid_dataset).position + getattr(kinematics, f"compute_{kinematic_variable}")(position) + + +@pytest.mark.parametrize("order", [0, -1, 1.0, "1"]) +def test_approximate_derivative_with_invalid_order(order): + """Test that an error is raised when the order is non-positive.""" + data = np.arange(10) + expected_exception = ValueError if isinstance(order, int) else TypeError + with pytest.raises(expected_exception): + kinematics.compute_time_derivative(data, order=order) + + +time_points_value_error = pytest.raises( + ValueError, + match="At least 2 time points are required to compute path length", +) + + +@pytest.mark.parametrize( + "start, stop, expected_exception", + [ + # full time ranges + (None, None, does_not_raise()), + (0, None, does_not_raise()), + (0, 9, does_not_raise()), + (0, 10, does_not_raise()), # xarray.sel will truncate to 0, 9 + (-1, 9, does_not_raise()), # xarray.sel will truncate to 0, 9 + # partial time ranges + (1, 8, does_not_raise()), + (1.5, 8.5, does_not_raise()), + (2, None, does_not_raise()), + # Empty time ranges + (9, 0, time_points_value_error), # start > stop + ("text", 9, time_points_value_error), # invalid start type + # Time range too short + (0, 0.5, time_points_value_error), + ], +) +def test_path_length_across_time_ranges( + valid_poses_dataset_uniform_linear_motion, + start, + stop, + expected_exception, +): + """Test path length computation for a uniform linear motion case, + across different time ranges. + + The test dataset ``valid_poses_dataset_uniform_linear_motion`` + contains 2 individuals ("id_0" and "id_1"), moving + along x=y and x=-y lines, respectively, at a constant velocity. + At each frame they cover a distance of sqrt(2) in x-y space, so in total + we expect a path length of sqrt(2) * num_segments, where num_segments is + the number of selected frames minus 1. + """ + position = valid_poses_dataset_uniform_linear_motion.position + with expected_exception: + path_length = kinematics.compute_path_length( + position, start=start, stop=stop + ) + + # Expected number of segments (displacements) in selected time range + num_segments = 9 # full time range: 10 frames - 1 + start = max(0, start) if start is not None else 0 + stop = min(9, stop) if stop is not None else 9 + if start is not None: + num_segments -= np.ceil(max(0, start)) + if stop is not None: + stop = min(9, stop) + num_segments -= 9 - np.floor(min(9, stop)) + + expected_path_length = xr.DataArray( + np.ones((2, 3)) * np.sqrt(2) * num_segments, + dims=["individuals", "keypoints"], + coords={ + "individuals": position.coords["individuals"], + "keypoints": position.coords["keypoints"], + }, + ) + xr.testing.assert_allclose(path_length, expected_path_length) + + +@pytest.mark.parametrize( + "nan_policy, expected_path_lengths_id_1, expected_exception", + [ + ( + "ffill", + np.array([np.sqrt(2) * 8, np.sqrt(2) * 9, np.nan]), + does_not_raise(), + ), + ( + "scale", + np.array([np.sqrt(2) * 9, np.sqrt(2) * 9, np.nan]), + does_not_raise(), + ), + ( + "invalid", # invalid value for nan_policy + np.zeros(3), + pytest.raises(ValueError, match="Invalid value for nan_policy"), + ), + ], +) +def test_path_length_with_nans( + valid_poses_dataset_uniform_linear_motion_with_nans, + nan_policy, + expected_path_lengths_id_1, + expected_exception, +): + """Test path length computation for a uniform linear motion case, + with varying number of missing values per individual and keypoint. 
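The segment counting above boils down to path length = sqrt(2) * (number of selected frames - 1); a minimal sketch for a single individual and keypoint:

    import numpy as np
    import xarray as xr

    from movement import kinematics

    time = np.arange(10)
    xy = np.stack([time, time], axis=-1).astype(float)  # x = y trajectory
    position = xr.DataArray(
        xy[:, np.newaxis, np.newaxis, :],  # add individuals and keypoints dims
        dims=["time", "individuals", "keypoints", "space"],
        coords={
            "time": time,
            "individuals": ["id_0"],
            "keypoints": ["centroid"],
            "space": ["x", "y"],
        },
    )

    full = kinematics.compute_path_length(position)                       # 9 segments
    partial = kinematics.compute_path_length(position, start=1, stop=8)   # 7 segments
    print(full.values, partial.values)  # ~12.73 and ~9.90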
+ + The test dataset ``valid_poses_dataset_uniform_linear_motion_with_nans`` + contains 2 individuals ("id_0" and "id_1"), moving + along x=y and x=-y lines, respectively, at a constant velocity. + At each frame they cover a distance of sqrt(2) in x-y space. + + Individual "id_1" has some missing values per keypoint: + - "centroid" is missing a value on the very first frame + - "left" is missing 5 values in middle frames (not at the edges) + - "right" is missing values in all frames + + Individual "id_0" has no missing values. + + Because the underlying motion is uniform linear, the "scale" policy should + perfectly restore the path length for individual "id_1" to its true value. + The "ffill" policy should do likewise if frames are missing in the middle, + but will not "correct" for missing values at the edges. + """ + position = valid_poses_dataset_uniform_linear_motion_with_nans.position + with expected_exception: + path_length = kinematics.compute_path_length( + position, + nan_policy=nan_policy, + ) + # Get path_length for individual "id_1" as a numpy array + path_length_id_1 = path_length.sel(individuals="id_1").values + # Check them against the expected values + np.testing.assert_allclose( + path_length_id_1, expected_path_lengths_id_1 ) - with pytest.raises(expected_exception): - kinematics._compute_approximate_derivative(data, order=order) + + +@pytest.mark.parametrize( + "nan_warn_threshold, expected_exception", + [ + (1, does_not_raise()), + (0.2, does_not_raise()), + (-1, pytest.raises(ValueError, match="between 0 and 1")), + ], +) +def test_path_length_warns_about_nans( + valid_poses_dataset_uniform_linear_motion_with_nans, + nan_warn_threshold, + expected_exception, + caplog, +): + """Test that a warning is raised when the number of missing values + exceeds a given threshold. + + See the docstring of ``test_path_length_with_nans`` for details + about what's in the dataset. + """ + position = valid_poses_dataset_uniform_linear_motion_with_nans.position + with expected_exception: + kinematics.compute_path_length( + position, nan_warn_threshold=nan_warn_threshold + ) + + if (nan_warn_threshold > 0.1) and (nan_warn_threshold < 0.5): + # Make sure that a warning was emitted + assert caplog.records[0].levelname == "WARNING" + assert "The result may be unreliable" in caplog.records[0].message + # Make sure that the NaN report only mentions + # the individual and keypoint that violate the threshold + assert caplog.records[1].levelname == "INFO" + assert "Individual: id_1" in caplog.records[1].message + assert "Individual: id_2" not in caplog.records[1].message + assert "left: 5/10 (50.0%)" in caplog.records[1].message + assert "right: 10/10 (100.0%)" in caplog.records[1].message + assert "centroid" not in caplog.records[1].message + + +@pytest.fixture +def valid_data_array_for_forward_vector(): + """Return a position data array for an individual with 3 keypoints + (left ear, right ear and nose), tracked for 4 frames, in x-y space. 
+ """ + time = [0, 1, 2, 3] + individuals = ["individual_0"] + keypoints = ["left_ear", "right_ear", "nose"] + space = ["x", "y"] + + ds = xr.DataArray( + [ + [[[1, 0], [-1, 0], [0, -1]]], # time 0 + [[[0, 1], [0, -1], [1, 0]]], # time 1 + [[[-1, 0], [1, 0], [0, 1]]], # time 2 + [[[0, -1], [0, 1], [-1, 0]]], # time 3 + ], + dims=["time", "individuals", "keypoints", "space"], + coords={ + "time": time, + "individuals": individuals, + "keypoints": keypoints, + "space": space, + }, + ) + return ds + + +@pytest.fixture +def invalid_input_type_for_forward_vector(valid_data_array_for_forward_vector): + """Return a numpy array of position values by individual, per keypoint, + over time. + """ + return valid_data_array_for_forward_vector.values + + +@pytest.fixture +def invalid_dimensions_for_forward_vector(valid_data_array_for_forward_vector): + """Return a position DataArray in which the ``keypoints`` dimension has + been dropped. + """ + return valid_data_array_for_forward_vector.sel(keypoints="nose", drop=True) + + +@pytest.fixture +def invalid_spatial_dimensions_for_forward_vector( + valid_data_array_for_forward_vector, +): + """Return a position DataArray containing three spatial dimensions.""" + dataarray_3d = valid_data_array_for_forward_vector.pad( + space=(0, 1), constant_values=0 + ) + return dataarray_3d.assign_coords(space=["x", "y", "z"]) + + +@pytest.fixture +def valid_data_array_for_forward_vector_with_nans( + valid_data_array_for_forward_vector, +): + """Return a position DataArray where position values are NaN for the + ``left_ear`` keypoint at time ``1``. + """ + nan_dataarray = valid_data_array_for_forward_vector.where( + (valid_data_array_for_forward_vector.time != 1) + | (valid_data_array_for_forward_vector.keypoints != "left_ear") + ) + return nan_dataarray + + +def test_compute_forward_vector(valid_data_array_for_forward_vector): + """Test that the correct output forward direction vectors + are computed from a valid mock dataset. 
+ """ + forward_vector = kinematics.compute_forward_vector( + valid_data_array_for_forward_vector, + "left_ear", + "right_ear", + camera_view="bottom_up", + ) + forward_vector_flipped = kinematics.compute_forward_vector( + valid_data_array_for_forward_vector, + "left_ear", + "right_ear", + camera_view="top_down", + ) + head_vector = kinematics.compute_head_direction_vector( + valid_data_array_for_forward_vector, + "left_ear", + "right_ear", + camera_view="bottom_up", + ) + known_vectors = np.array([[[0, -1]], [[1, 0]], [[0, 1]], [[-1, 0]]]) + + assert ( + isinstance(forward_vector, xr.DataArray) + and ("space" in forward_vector.dims) + and ("keypoints" not in forward_vector.dims) + ) + assert np.equal(forward_vector.values, known_vectors).all() + assert np.equal(forward_vector_flipped.values, known_vectors * -1).all() + assert head_vector.equals(forward_vector) + + +@pytest.mark.parametrize( + "input_data, expected_error, expected_match_str, keypoints", + [ + ( + "invalid_input_type_for_forward_vector", + TypeError, + "must be an xarray.DataArray", + ["left_ear", "right_ear"], + ), + ( + "invalid_dimensions_for_forward_vector", + ValueError, + "Input data must contain ['keypoints']", + ["left_ear", "right_ear"], + ), + ( + "invalid_spatial_dimensions_for_forward_vector", + ValueError, + "must have exactly 2 spatial dimensions", + ["left_ear", "right_ear"], + ), + ( + "valid_data_array_for_forward_vector", + ValueError, + "keypoints may not be identical", + ["left_ear", "left_ear"], + ), + ], +) +def test_compute_forward_vector_with_invalid_input( + input_data, keypoints, expected_error, expected_match_str, request +): + """Test that ``compute_forward_vector`` catches errors + correctly when passed invalid inputs. + """ + # Get fixture + input_data = request.getfixturevalue(input_data) + + # Catch error + with pytest.raises(expected_error, match=re.escape(expected_match_str)): + kinematics.compute_forward_vector( + input_data, keypoints[0], keypoints[1] + ) + + +def test_nan_behavior_forward_vector( + valid_data_array_for_forward_vector_with_nans, +): + """Test that ``compute_forward_vector()`` generates the + expected output for a valid input DataArray containing ``NaN`` + position values at a single time (``1``) and keypoint + (``left_ear``). 
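For reference, the first frame of the fixture above as a stand-alone call; the geometry and the expected output are taken directly from the test:

    import numpy as np
    import xarray as xr

    from movement import kinematics

    # Left ear at (1, 0), right ear at (-1, 0): viewed bottom-up, the animal
    # faces towards (0, -1), as in the first frame of the fixture above
    position = xr.DataArray(
        np.array([[[[1.0, 0.0], [-1.0, 0.0]]]]),
        dims=["time", "individuals", "keypoints", "space"],
        coords={
            "time": [0],
            "individuals": ["individual_0"],
            "keypoints": ["left_ear", "right_ear"],
            "space": ["x", "y"],
        },
    )

    forward = kinematics.compute_forward_vector(
        position, "left_ear", "right_ear", camera_view="bottom_up"
    )
    print(forward.values)  # [[[ 0. -1.]]]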
+ """ + forward_vector = kinematics.compute_forward_vector( + valid_data_array_for_forward_vector_with_nans, "left_ear", "right_ear" + ) + assert ( + np.isnan(forward_vector.values[1, 0, :]).all() + and not np.isnan(forward_vector.values[[0, 2, 3], 0, :]).any() + ) + + +@pytest.mark.parametrize( + "dim, expected_data", + [ + ( + "individuals", + np.array( + [ + [ + [0.0, 1.0, 1.0], + [1.0, np.sqrt(2), 0.0], + [1.0, 2.0, np.sqrt(2)], + ], + [ + [2.0, np.sqrt(5), 1.0], + [3.0, np.sqrt(10), 2.0], + [np.sqrt(5), np.sqrt(8), np.sqrt(2)], + ], + ] + ), + ), + ( + "keypoints", + np.array( + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, np.sqrt(5)], [3.0, 1.0]]] + ), + ), + ], +) +def test_cdist_with_known_values( + dim, expected_data, valid_poses_dataset_uniform_linear_motion +): + """Test the computation of pairwise distances with known values.""" + labels_dim = "keypoints" if dim == "individuals" else "individuals" + input_dataarray = valid_poses_dataset_uniform_linear_motion.position.sel( + time=slice(0, 1) + ) # Use only the first two frames for simplicity + pairs = input_dataarray[dim].values[:2] + expected = xr.DataArray( + expected_data, + coords=[ + input_dataarray.time.values, + getattr(input_dataarray, labels_dim).values, + getattr(input_dataarray, labels_dim).values, + ], + dims=["time", pairs[0], pairs[1]], + ) + a = input_dataarray.sel({dim: pairs[0]}) + b = input_dataarray.sel({dim: pairs[1]}) + result = kinematics._cdist(a, b, dim) + xr.testing.assert_equal( + result, + expected, + ) + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_bboxes_dataset", + ], +) +@pytest.mark.parametrize( + "selection_fn", + [ + # individuals dim is scalar, + # poses: multiple keypoints + # bboxes: missing keypoints dim + # e.g. comparing 2 individuals from the same data array + lambda position: ( + position.isel(individuals=0), + position.isel(individuals=1), + ), + # individuals dim is 1D + # poses: multiple keypoints + # bboxes: missing keypoints dim + # e.g. comparing 2 single-individual data arrays + lambda position: ( + position.where( + position.individuals == position.individuals[0], drop=True + ).squeeze(), + position.where( + position.individuals == position.individuals[1], drop=True + ).squeeze(), + ), + # both individuals and keypoints dims are scalar (poses only) + # e.g. comparing 2 individuals from the same data array, + # at the same keypoint + lambda position: ( + position.isel(individuals=0, keypoints=0), + position.isel(individuals=1, keypoints=0), + ), + # individuals dim is scalar, keypoints dim is 1D (poses only) + # e.g. comparing 2 single-individual, single-keypoint data arrays + lambda position: ( + position.where( + position.keypoints == position.keypoints[0], drop=True + ).isel(individuals=0), + position.where( + position.keypoints == position.keypoints[0], drop=True + ).isel(individuals=1), + ), + ], + ids=[ + "dim_has_ndim_0", + "dim_has_ndim_1", + "labels_dim_has_ndim_0", + "labels_dim_has_ndim_1", + ], +) +def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request): + """Test that the pairwise distances data array is successfully + returned regardless of whether the input DataArrays have + ``dim`` ("individuals") and ``labels_dim`` ("keypoints") + being either scalar (ndim=0) or 1D (ndim=1), + or if ``labels_dim`` is missing. 
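The same machinery is exposed publicly through `compute_pairwise_distances`, exercised further below; a minimal sketch with two individuals a constant unit apart:

    import numpy as np
    import xarray as xr

    from movement import kinematics

    # Two individuals one unit apart along y, one keypoint, three frames
    data = np.zeros((3, 2, 1, 2))
    data[:, 1, :, 1] = 1.0  # id_1 sits at y = 1, id_0 at the origin
    position = xr.DataArray(
        data,
        dims=["time", "individuals", "keypoints", "space"],
        coords={
            "time": np.arange(3),
            "individuals": ["id_0", "id_1"],
            "keypoints": ["centroid"],
            "space": ["x", "y"],
        },
    )

    # A single pair of individuals -> a single DataArray of distances
    dist = kinematics.compute_pairwise_distances(
        position, "individuals", {"id_0": "id_1"}
    )
    print(dist.values)  # all ones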
+ """ + if request.node.callspec.id not in [ + "labels_dim_has_ndim_0-valid_bboxes_dataset", + "labels_dim_has_ndim_1-valid_bboxes_dataset", + ]: # Skip tests with keypoints dim for bboxes + valid_dataset = request.getfixturevalue(valid_dataset) + position = valid_dataset.position + a, b = selection_fn(position) + assert isinstance(kinematics._cdist(a, b, "individuals"), xr.DataArray) + + +@pytest.mark.parametrize( + "dim, pairs, expected_data_vars", + [ + ("individuals", {"id_1": ["id_2"]}, None), # list input + ("individuals", {"id_1": "id_2"}, None), # string input + ( + "individuals", + {"id_1": ["id_2"], "id_2": "id_1"}, + [("id_1", "id_2"), ("id_2", "id_1")], + ), + ("individuals", "all", None), # all pairs + ("keypoints", {"centroid": ["left"]}, None), # list input + ("keypoints", {"centroid": "left"}, None), # string input + ( + "keypoints", + {"centroid": ["left"], "left": "right"}, + [("centroid", "left"), ("left", "right")], + ), + ( + "keypoints", + "all", + [("centroid", "left"), ("centroid", "right"), ("left", "right")], + ), # all pairs + ], +) +def test_compute_pairwise_distances_with_valid_pairs( + valid_poses_dataset_uniform_linear_motion, dim, pairs, expected_data_vars +): + """Test that the expected pairwise distances are computed + for valid ``pairs`` inputs. + """ + result = kinematics.compute_pairwise_distances( + valid_poses_dataset_uniform_linear_motion.position, dim, pairs + ) + if isinstance(result, dict): + expected_data_vars = [ + f"dist_{pair[0]}_{pair[1]}" for pair in expected_data_vars + ] + assert set(result.keys()) == set(expected_data_vars) + else: # expect single DataArray + assert isinstance(result, xr.DataArray) + + +@pytest.mark.parametrize( + "ds, dim, pairs", + [ + ( + "valid_poses_dataset_uniform_linear_motion", + "invalid_dim", + {"id_1": "id_2"}, + ), # invalid dim + ( + "valid_poses_dataset_uniform_linear_motion", + "keypoints", + "invalid_string", + ), # invalid pairs + ( + "valid_poses_dataset_uniform_linear_motion", + "individuals", + {}, + ), # empty pairs + ("missing_dim_poses_dataset", "keypoints", "all"), # invalid dataset + ( + "missing_dim_bboxes_dataset", + "individuals", + "all", + ), # invalid dataset + ], +) +def test_compute_pairwise_distances_with_invalid_input( + ds, dim, pairs, request +): + """Test that an error is raised for invalid inputs.""" + with pytest.raises(ValueError): + kinematics.compute_pairwise_distances( + request.getfixturevalue(ds).position, dim, pairs + ) + + # @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) + # @pytest.fixture(mult) + # def test_multiview(self): + # ds = request.getfixturevalue("multi_view_dataset") + # pass + # result = kinematics.compute_displacement(multi_view_dataset.position) diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py new file mode 100644 index 00000000..2f80459d --- /dev/null +++ b/tests/test_unit/test_load_bboxes.py @@ -0,0 +1,472 @@ +"""Test suite for the load_bboxes module.""" + +import ast +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from movement.io import load_bboxes +from movement.validators.datasets import ValidBboxesDataset + + +@pytest.fixture() +def via_tracks_file(): + """Return the file path for a VIA tracks .csv file.""" + via_sample_file_name = "VIA_multiple-crabs_5-frames_labels.csv" + return pytest.DATA_PATHS.get(via_sample_file_name) + + +@pytest.fixture() +def valid_from_numpy_inputs_required_arrays(): + """Return a dictionary with 
valid numpy arrays for the `from_numpy()` + loader, excluding the optional `frame_array`. + """ + n_frames = 5 + n_individuals = 86 + n_space = 2 + individual_names_array = np.arange(n_individuals).reshape(-1, 1) + + rng = np.random.default_rng(seed=42) + + return { + "position_array": rng.random((n_frames, n_individuals, n_space)), + "shape_array": rng.random((n_frames, n_individuals, n_space)), + "confidence_array": rng.random((n_frames, n_individuals)), + "individual_names": [ + f"id_{id}" for id in individual_names_array.squeeze() + ], + } + + +@pytest.fixture() +def valid_from_numpy_inputs_all_arrays( + valid_from_numpy_inputs_required_arrays, +): + """Return a dictionary with valid numpy arrays for the from_numpy() loader, + including a `frame_array` that ranges from frame 1 to 5. + """ + n_frames = valid_from_numpy_inputs_required_arrays["position_array"].shape[ + 0 + ] + first_frame_number = 1 # should match sample file + + valid_from_numpy_inputs_required_arrays["frame_array"] = np.arange( + first_frame_number, first_frame_number + n_frames + ).reshape(-1, 1) + + return valid_from_numpy_inputs_required_arrays + + +@pytest.fixture() +def df_input_via_tracks_small(via_tracks_file): + """Return the first 3 rows of the VIA tracks .csv file as a dataframe.""" + df = pd.read_csv(via_tracks_file, sep=",", header=0) + return df.loc[:2, :] + + +@pytest.fixture() +def df_input_via_tracks_small_with_confidence(df_input_via_tracks_small): + """Return a dataframe with the first three rows of the VIA tracks .csv file + and add confidence values to the bounding boxes. + """ + df = update_attribute_column( + df_input=df_input_via_tracks_small, + attribute_column_name="region_attributes", + dict_to_append={"confidence": "0.5"}, + ) + + return df + + +@pytest.fixture() +def df_input_via_tracks_small_with_frame_number(df_input_via_tracks_small): + """Return a dataframe with the first three rows of the VIA tracks .csv file + and add frame number values to the bounding boxes. 
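The arrays assembled by the `from_numpy` fixtures above are all the loader needs; a minimal stand-alone sketch with the same layout (frames x individuals x space):

    import numpy as np

    from movement.io import load_bboxes

    n_frames, n_individuals = 5, 3
    rng = np.random.default_rng(seed=42)

    ds = load_bboxes.from_numpy(
        position_array=rng.random((n_frames, n_individuals, 2)),  # bbox centroids
        shape_array=rng.random((n_frames, n_individuals, 2)),     # widths and heights
        confidence_array=rng.random((n_frames, n_individuals)),
        individual_names=[f"id_{i}" for i in range(n_individuals)],
        frame_array=np.arange(1, n_frames + 1).reshape(-1, 1),    # frames start at 1
        fps=30,
        source_software="VIA-tracks",
    )
    print(ds.sizes, ds.time_unit)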
+ """ + df = update_attribute_column( + df_input=df_input_via_tracks_small, + attribute_column_name="file_attributes", + dict_to_append={"frame": "1"}, + ) + + return df + + +def update_attribute_column(df_input, attribute_column_name, dict_to_append): + """Update an attributes column in the dataframe.""" + # copy the dataframe + df = df_input.copy() + + # get the attributes column and convert to dict + attributes_dicts = [ast.literal_eval(d) for d in df[attribute_column_name]] + + # update the dict + for d in attributes_dicts: + d.update(dict_to_append) + + # update the region_attributes column in the dataframe + df[attribute_column_name] = [str(d) for d in attributes_dicts] + return df + + +def assert_dataset( + dataset, file_path=None, expected_source_software=None, expected_fps=None +): + """Assert that the dataset is a proper ``movement`` Dataset.""" + assert isinstance(dataset, xr.Dataset) + + # Expected variables are present and of right shape/type + for var in ["position", "shape", "confidence"]: + assert var in dataset.data_vars + assert isinstance(dataset[var], xr.DataArray) + assert dataset.position.ndim == 3 + assert dataset.shape.ndim == 3 + assert dataset.confidence.shape == dataset.position.shape[:-1] + + # Check the dims and coords + DIM_NAMES = ValidBboxesDataset.DIM_NAMES + assert all([i in dataset.dims for i in DIM_NAMES]) + for d, dim in enumerate(DIM_NAMES[1:]): + assert dataset.sizes[dim] == dataset.position.shape[d + 1] + assert all([isinstance(s, str) for s in dataset.coords[dim].values]) + assert all([i in dataset.coords["space"] for i in ["x", "y"]]) + + # Check the metadata attributes + assert ( + dataset.source_file is None + if file_path is None + else dataset.source_file == file_path.as_posix() + ) + assert ( + dataset.source_software is None + if expected_source_software is None + else dataset.source_software == expected_source_software + ) + assert ( + dataset.fps is None + if expected_fps is None + else dataset.fps == expected_fps + ) + + +def assert_time_coordinates(ds, fps, start_frame): + """Assert that the time coordinates are as expected, depending on + fps value and start_frame. + """ + # scale time coordinates with 1/fps if provided + scale = 1 / fps if fps else 1 + + # assert numpy array of time coordinates + np.testing.assert_allclose( + ds.coords["time"].data, + np.array( + [ + f * scale + for f in range( + start_frame, len(ds.coords["time"].data) + start_frame + ) + ] + ), + ) + + +@pytest.mark.parametrize("source_software", ["Unknown", "VIA-tracks"]) +@pytest.mark.parametrize("fps", [None, 30, 60.0]) +@pytest.mark.parametrize("use_frame_numbers_from_file", [True, False]) +def test_from_file(source_software, fps, use_frame_numbers_from_file): + """Test that the from_file() function delegates to the correct + loader function according to the source_software. 
+ """ + software_to_loader = { + "VIA-tracks": "movement.io.load_bboxes.from_via_tracks_file", + } + + if source_software == "Unknown": + with pytest.raises(ValueError, match="Unsupported source"): + load_bboxes.from_file( + "some_file", + source_software, + fps, + use_frame_numbers_from_file=use_frame_numbers_from_file, + ) + else: + with patch(software_to_loader[source_software]) as mock_loader: + load_bboxes.from_file( + "some_file", + source_software, + fps, + use_frame_numbers_from_file=use_frame_numbers_from_file, + ) + mock_loader.assert_called_with( + "some_file", + fps, + use_frame_numbers_from_file=use_frame_numbers_from_file, + ) + + +@pytest.mark.parametrize("fps", [None, 30, 60.0]) +@pytest.mark.parametrize("use_frame_numbers_from_file", [True, False]) +def test_from_via_tracks_file( + via_tracks_file, fps, use_frame_numbers_from_file +): + """Test that loading tracked bounding box data from + a valid VIA tracks .csv file returns a proper Dataset + and that the time coordinates are as expected. + """ + # run general dataset checks + ds = load_bboxes.from_via_tracks_file( + via_tracks_file, fps, use_frame_numbers_from_file + ) + assert_dataset(ds, via_tracks_file, "VIA-tracks", fps) + + # check time coordinates are as expected + # in sample VIA tracks .csv file frame numbers start from 1 + start_frame = 1 if use_frame_numbers_from_file else 0 + assert_time_coordinates(ds, fps, start_frame) + + +@pytest.mark.parametrize( + "valid_from_numpy_inputs", + [ + "valid_from_numpy_inputs_required_arrays", + "valid_from_numpy_inputs_all_arrays", + ], +) +@pytest.mark.parametrize("fps", [None, 30, 60.0]) +@pytest.mark.parametrize("source_software", [None, "VIA-tracks"]) +def test_from_numpy(valid_from_numpy_inputs, fps, source_software, request): + """Test that loading bounding boxes trajectories from the input + numpy arrays returns a proper Dataset. + """ + # get the input arrays + from_numpy_inputs = request.getfixturevalue(valid_from_numpy_inputs) + + # run general dataset checks + ds = load_bboxes.from_numpy( + **from_numpy_inputs, + fps=fps, + source_software=source_software, + ) + assert_dataset( + ds, expected_source_software=source_software, expected_fps=fps + ) + + # check time coordinates are as expected + if "frame_array" in from_numpy_inputs: + start_frame = from_numpy_inputs["frame_array"][0, 0] + else: + start_frame = 0 + assert_time_coordinates(ds, fps, start_frame) + + +@pytest.mark.parametrize( + "via_column_name, list_keys, cast_fn, expected_attribute_array", + [ + ( + "file_attributes", + ["clip"], + int, + np.array([123] * 3), # .reshape(-1, 1), + ), + ( + "region_shape_attributes", + ["name"], + str, + np.array(["rect"] * 3), # .reshape(-1, 1), + ), + ( + "region_shape_attributes", + ["x", "y"], + float, + np.array( + [ + [526.2366942646654, 393.280914246804], + [2565, 468], + [759.6484377108334, 136.60946673708338], + ] + ).reshape(-1, 2), + ), + ( + "region_shape_attributes", + ["width", "height"], + float, + np.array([[46, 38], [41, 30], [29, 25]]).reshape(-1, 2), + ), + ( + "region_attributes", + ["track"], + int, + np.array([71, 70, 69]), # .reshape(-1, 1), + ), + ], +) +def test_via_attribute_column_to_numpy( + df_input_via_tracks_small, + via_column_name, + list_keys, + cast_fn, + expected_attribute_array, +): + """Test that the function correctly extracts the desired data from the VIA + attributes. 
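Outside the test suite, the loader call these tests revolve around looks like this (the file name is a placeholder for any VIA tracks export):

    from movement.io import load_bboxes

    # "crabs_tracks.csv" is a placeholder path to a VIA tracks .csv export
    ds = load_bboxes.from_via_tracks_file(
        "crabs_tracks.csv",
        fps=30,                            # scales time coordinates to seconds
        use_frame_numbers_from_file=True,  # keep the frame numbers from the file
    )
    print(ds.sizes, ds.time_unit, ds.fps)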
+ """ + attribute_array = load_bboxes._via_attribute_column_to_numpy( + df=df_input_via_tracks_small, + via_column_name=via_column_name, + list_keys=list_keys, + cast_fn=cast_fn, + ) + + assert np.array_equal(attribute_array, expected_attribute_array) + + +@pytest.mark.parametrize( + "df_input, expected_array", + [ + ("df_input_via_tracks_small", np.full((3,), np.nan)), + ( + "df_input_via_tracks_small_with_confidence", + np.array([0.5, 0.5, 0.5]), + ), + ], +) +def test_extract_confidence_from_via_tracks_df( + df_input, expected_array, request +): + """Test that the function correctly extracts the confidence values from + the VIA dataframe. + """ + df = request.getfixturevalue(df_input) + confidence_array = load_bboxes._extract_confidence_from_via_tracks_df(df) + + assert np.array_equal(confidence_array, expected_array, equal_nan=True) + + +@pytest.mark.parametrize( + "df_input, expected_array", + [ + ( + "df_input_via_tracks_small", + np.ones((3,)), + ), # extract from filename + ( + "df_input_via_tracks_small_with_frame_number", + np.array([1, 1, 1]), + ), # extract from file_attributes + ], +) +def test_extract_frame_number_from_via_tracks_df( + df_input, expected_array, request +): + """Test that the function correctly extracts the frame number values from + the VIA dataframe. + """ + df = request.getfixturevalue(df_input) + frame_array = load_bboxes._extract_frame_number_from_via_tracks_df(df) + + assert np.array_equal(frame_array, expected_array) + + +@pytest.mark.parametrize( + "fps, expected_fps, expected_time_unit", + [ + (None, None, "frames"), + (-5, None, "frames"), + (0, None, "frames"), + (30, 30.0, "seconds"), + (60.0, 60.0, "seconds"), + ], +) +@pytest.mark.parametrize("use_frame_numbers_from_file", [True, False]) +def test_fps_and_time_coords( + via_tracks_file, + fps, + expected_fps, + expected_time_unit, + use_frame_numbers_from_file, +): + """Test that fps conversion is as expected and time coordinates are set + according to the expected fps. + """ + ds = load_bboxes.from_via_tracks_file( + via_tracks_file, + fps=fps, + use_frame_numbers_from_file=use_frame_numbers_from_file, + ) + + # load dataset with frame numbers from file + ds_in_frames_from_file = load_bboxes.from_via_tracks_file( + via_tracks_file, + fps=None, + use_frame_numbers_from_file=True, + ) + + # check time unit + assert ds.time_unit == expected_time_unit + + # check fps is as expected + if expected_fps is None: + assert ds.fps is expected_fps + else: + assert ds.fps == expected_fps + + # check time coordinates + if use_frame_numbers_from_file: + start_frame = ds_in_frames_from_file.coords["time"].data[0] + else: + start_frame = 0 + assert_time_coordinates(ds, expected_fps, start_frame) + + +def test_df_from_via_tracks_file(via_tracks_file): + """Test that the helper function correctly reads the VIA tracks .csv file + as a dataframe. + """ + df = load_bboxes._df_from_via_tracks_file(via_tracks_file) + + assert isinstance(df, pd.DataFrame) + assert len(df.frame_number.unique()) == 5 + assert ( + df.shape[0] == len(df.ID.unique()) * 5 + ) # all individuals in all frames (even if nan) + assert list(df.columns) == [ + "ID", + "frame_number", + "x", + "y", + "w", + "h", + "confidence", + ] + + +def test_position_numpy_array_from_via_tracks_file(via_tracks_file): + """Test the extracted position array from the VIA tracks .csv file + represents the centroid of the bbox. 
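In concrete numbers, the centroid convention stated just above (and computed in the test body below) works out as follows for the first row of the VIA dataframe used earlier:

    import numpy as np

    # VIA stores each box as its top-left corner (x, y) plus width and height
    # (w, h); movement's position array holds the box centroid instead
    x, y, w, h = 526.2366942646654, 393.280914246804, 46.0, 38.0
    centroid = np.array([x + w / 2, y + h / 2])
    print(centroid)  # [549.23669426 412.28091425]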
+ """ + # Extract numpy arrays from VIA tracks .csv file + bboxes_arrays = load_bboxes._numpy_arrays_from_via_tracks_file( + via_tracks_file + ) + + # Read VIA tracks .csv file as a dataframe + df = load_bboxes._df_from_via_tracks_file(via_tracks_file) + + # Compute centroid positions from the dataframe + # (go thru in the same order as ID array) + list_derived_centroids = [] + for id in bboxes_arrays["ID_array"]: + df_one_id = df[df["ID"] == id.item()] + centroid_position = np.array( + [df_one_id.x + df_one_id.w / 2, df_one_id.y + df_one_id.h / 2] + ).T # frames, xy + list_derived_centroids.append(centroid_position) + + # Compare to extracted position array + assert np.allclose( + bboxes_arrays["position_array"], # frames, individuals, xy + np.stack(list_derived_centroids, axis=1), + ) diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index f5e728bb..148d247f 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -4,12 +4,12 @@ import numpy as np import pytest import xarray as xr -from pytest import POSE_DATA_PATHS +from pytest import DATA_PATHS from sleap_io.io.slp import read_labels, write_labels from sleap_io.model.labels import LabeledFrame, Labels -from movement import MovementDataset from movement.io import load_poses +from movement.validators.datasets import ValidPosesDataset class TestLoadPoses: @@ -18,9 +18,7 @@ class TestLoadPoses: @pytest.fixture def sleap_slp_file_without_tracks(self, tmp_path): """Mock and return the path to a SLEAP .slp file without tracks.""" - sleap_file = POSE_DATA_PATHS.get( - "SLEAP_single-mouse_EPM.predictions.slp" - ) + sleap_file = DATA_PATHS.get("SLEAP_single-mouse_EPM.predictions.slp") labels = read_labels(sleap_file) file_path = tmp_path / "track_is_none.slp" lfs = [] @@ -48,7 +46,7 @@ def sleap_slp_file_without_tracks(self, tmp_path): @pytest.fixture def sleap_h5_file_without_tracks(self, tmp_path): """Mock and return the path to a SLEAP .h5 file without tracks.""" - sleap_file = POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") + sleap_file = DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") file_path = tmp_path / "track_is_none.h5" with h5py.File(sleap_file, "r") as f1, h5py.File(file_path, "w") as f2: for key in list(f1.keys()): @@ -80,7 +78,7 @@ def assert_dataset( assert dataset.position.ndim == 4 assert dataset.confidence.shape == dataset.position.shape[:-1] # Check the dims and coords - DIM_NAMES = MovementDataset.dim_names + DIM_NAMES = ValidPosesDataset.DIM_NAMES assert all([i in dataset.dims for i in DIM_NAMES]) for d, dim in enumerate(DIM_NAMES[1:]): assert dataset.sizes[dim] == dataset.position.shape[d + 1] @@ -120,7 +118,7 @@ def test_load_from_sleap_file_without_tracks( sleap_file_without_tracks ) ds_from_tracked = load_poses.from_sleap_file( - POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") + DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") ) # Check if the "individuals" coordinate matches # the assigned default "individuals_0" @@ -153,8 +151,8 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same( """Test that loading pose tracks from SLEAP .slp and .h5 files return the same Dataset. 
""" - slp_file_path = POSE_DATA_PATHS.get(slp_file) - h5_file_path = POSE_DATA_PATHS.get(h5_file) + slp_file_path = DATA_PATHS.get(slp_file) + h5_file_path = DATA_PATHS.get(h5_file) ds_from_slp = load_poses.from_sleap_file(slp_file_path) ds_from_h5 = load_poses.from_sleap_file(h5_file_path) xr.testing.assert_allclose(ds_from_h5, ds_from_slp) @@ -171,23 +169,28 @@ def test_load_from_dlc_file(self, file_name): """Test that loading pose tracks from valid DLC files returns a proper Dataset. """ - file_path = POSE_DATA_PATHS.get(file_name) + file_path = DATA_PATHS.get(file_name) ds = load_poses.from_dlc_file(file_path) self.assert_dataset(ds, file_path, "DeepLabCut") - def test_load_from_dlc_df(self, dlc_style_df): + @pytest.mark.parametrize( + "source_software", ["DeepLabCut", "LightningPose", None] + ) + def test_load_from_dlc_style_df(self, dlc_style_df, source_software): """Test that loading pose tracks from a valid DLC-style DataFrame returns a proper Dataset. """ - ds = load_poses.from_dlc_df(dlc_style_df) - self.assert_dataset(ds) + ds = load_poses.from_dlc_style_df( + dlc_style_df, source_software=source_software + ) + self.assert_dataset(ds, expected_source_software=source_software) def test_load_from_dlc_file_csv_or_h5_file_returns_same(self): """Test that loading pose tracks from DLC .csv and .h5 files return the same Dataset. """ - csv_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.csv") - h5_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") + csv_file_path = DATA_PATHS.get("DLC_single-wasp.predictions.csv") + h5_file_path = DATA_PATHS.get("DLC_single-wasp.predictions.h5") ds_from_csv = load_poses.from_dlc_file(csv_file_path) ds_from_h5 = load_poses.from_dlc_file(h5_file_path) xr.testing.assert_allclose(ds_from_h5, ds_from_csv) @@ -205,7 +208,7 @@ def test_load_from_dlc_file_csv_or_h5_file_returns_same(self): def test_fps_and_time_coords(self, fps, expected_fps, expected_time_unit): """Test that time coordinates are set according to the provided fps.""" ds = load_poses.from_sleap_file( - POSE_DATA_PATHS.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"), + DATA_PATHS.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"), fps=fps, ) assert ds.time_unit == expected_time_unit @@ -229,7 +232,7 @@ def test_load_from_lp_file(self, file_name): """Test that loading pose tracks from valid LightningPose (LP) files returns a proper Dataset. """ - file_path = POSE_DATA_PATHS.get(file_name) + file_path = DATA_PATHS.get(file_name) ds = load_poses.from_lp_file(file_path) self.assert_dataset(ds, file_path, "LightningPose") @@ -238,7 +241,7 @@ def test_load_from_lp_or_dlc_file_returns_same(self): using either the `from_lp_file` or `from_dlc_file` function returns the same Dataset (except for the source_software). """ - file_path = POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") + file_path = DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") ds_drom_lp = load_poses.from_lp_file(file_path) ds_from_dlc = load_poses.from_dlc_file(file_path) xr.testing.assert_allclose(ds_from_dlc, ds_drom_lp) @@ -249,7 +252,7 @@ def test_load_multi_individual_from_lp_file_raises(self): """Test that loading a multi-individual .csv file using the `from_lp_file` function raises a ValueError. 
""" - file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv") + file_path = DATA_PATHS.get("DLC_two-mice.predictions.csv") with pytest.raises(ValueError): load_poses.from_lp_file(file_path) @@ -274,3 +277,44 @@ def test_from_file_delegates_correctly(self, source_software, fps): with patch(software_to_loader[source_software]) as mock_loader: load_poses.from_file("some_file", source_software, fps) mock_loader.assert_called_with("some_file", fps) + + @pytest.mark.parametrize("source_software", [None, "SLEAP"]) + def test_from_numpy_valid( + self, + valid_position_array, + source_software, + ): + """Test that loading pose tracks from a multi-animal numpy array + with valid parameters returns a proper Dataset. + """ + valid_position = valid_position_array("multi_individual_array") + rng = np.random.default_rng(seed=42) + valid_confidence = rng.random(valid_position.shape[:-1]) + + ds = load_poses.from_numpy( + valid_position, + valid_confidence, + individual_names=["mouse1", "mouse2"], + keypoint_names=["snout", "tail"], + fps=None, + source_software=source_software, + ) + self.assert_dataset(ds, expected_source_software=source_software) + + def from_multiview_files(self): + """Test that the from_file() function delegates to the correct + loader function according to the source_software. + """ + view_names = ["view_0", "view_1"] + file_path_dict = { + view: DATA_PATHS.get("DLC_single-wasp.predictions.h5") + for view in view_names + } + + multi_view_ds = load_poses.from_multi_view( + file_path_dict, source_software="DeepLabCut" + ) + + assert isinstance(multi_view_ds, xr.Dataset) + assert "view" in multi_view_ds.dims + assert multi_view_ds.view.values.tolist() == view_names diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py index 40e4415d..348a3687 100644 --- a/tests/test_unit/test_logging.py +++ b/tests/test_unit/test_logging.py @@ -1,8 +1,9 @@ import logging import pytest +import xarray as xr -from movement.logging import log_error, log_warning +from movement.utils.logging import log_error, log_to_attrs, log_warning log_messages = { "DEBUG": "This is a debug message", @@ -43,3 +44,43 @@ def test_log_warning(caplog): log_warning("This is a test warning") assert caplog.records[0].message == "This is a test warning" assert caplog.records[0].levelname == "WARNING" + + +@pytest.mark.parametrize( + "input_data", + [ + "valid_poses_dataset", + "valid_bboxes_dataset", + ], +) +@pytest.mark.parametrize( + "selector_fn, expected_selector_type", + [ + (lambda ds: ds, xr.Dataset), # take full dataset + (lambda ds: ds.position, xr.DataArray), # take position data array + ], +) +def test_log_to_attrs( + input_data, selector_fn, expected_selector_type, request +): + """Test that the ``log_to_attrs()`` decorator appends + log entries to the dataset's or the data array's ``log`` + attribute and check that ``attrs`` contains all the expected values. 
+ """ + + # a fake operation on the dataset to log + @log_to_attrs + def fake_func(data, arg, kwarg=None): + return data + + # apply operation to dataset or data array + dataset = request.getfixturevalue(input_data) + input_data = selector_fn(dataset) + output_data = fake_func(input_data, "test1", kwarg="test2") + + # check the log in the dataset is as expected + assert isinstance(output_data, expected_selector_type) + assert "log" in output_data.attrs + assert output_data.attrs["log"][0]["operation"] == "fake_func" + assert output_data.attrs["log"][0]["arg_1"] == "test1" + assert output_data.attrs["log"][0]["kwarg"] == "test2" diff --git a/tests/test_unit/test_move_accessor.py b/tests/test_unit/test_move_accessor.py deleted file mode 100644 index ae6b4962..00000000 --- a/tests/test_unit/test_move_accessor.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest -import xarray as xr - - -class TestMovementDataset: - """Test suite for the MovementDataset class.""" - - def test_compute_kinematics_with_valid_dataset( - self, valid_poses_dataset, kinematic_property - ): - """Test that computing a kinematic property of a valid - pose dataset via accessor methods returns an instance of - xr.DataArray. - """ - result = getattr( - valid_poses_dataset.move, f"compute_{kinematic_property}" - )() - assert isinstance(result, xr.DataArray) - - def test_compute_kinematics_with_invalid_dataset( - self, invalid_poses_dataset, kinematic_property - ): - """Test that computing a kinematic property of an invalid - pose dataset via accessor methods raises the appropriate error. - """ - expected_exception = ( - ValueError - if isinstance(invalid_poses_dataset, xr.Dataset) - else AttributeError - ) - with pytest.raises(expected_exception): - getattr( - invalid_poses_dataset.move, f"compute_{kinematic_property}" - )() - - @pytest.mark.parametrize( - "method", ["compute_invalid_property", "do_something"] - ) - def test_invalid_compute(self, valid_poses_dataset, method): - """Test that invalid accessor method calls raise an AttributeError.""" - with pytest.raises(AttributeError): - getattr(valid_poses_dataset.move, method)() diff --git a/tests/test_unit/test_reports.py b/tests/test_unit/test_reports.py new file mode 100644 index 00000000..79c3bc89 --- /dev/null +++ b/tests/test_unit/test_reports.py @@ -0,0 +1,129 @@ +import pytest + +from movement.utils.reports import report_nan_values + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_poses_dataset", + "valid_bboxes_dataset", + "valid_poses_dataset_with_nan", + "valid_bboxes_dataset_with_nan", + ], +) +@pytest.mark.parametrize( + "data_selection, list_expected_individuals_indices", + [ + (lambda ds: ds.position, [0, 1]), # full position data array + ( + lambda ds: ds.position.isel(individuals=0), + [0], + ), # position of individual 0 only + ], +) +def test_report_nan_values_in_position_selecting_individual( + valid_dataset, + data_selection, + list_expected_individuals_indices, + request, +): + """Test that the nan-value reporting function handles position data + with specific ``individuals`` , and that the data array name (position) + and only the relevant individuals are included in the report. 
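A sketch of the `log_to_attrs` decorator exercised above, applied to a made-up operation (the function name and its arguments are purely illustrative):

    import xarray as xr

    from movement.utils.logging import log_to_attrs

    @log_to_attrs
    def shift_time(data, offset, units=None):
        """Toy operation whose call ends up recorded in the output's attrs."""
        return data.assign_coords(time=data.time + offset)

    da = xr.DataArray([0.0, 1.0, 2.0], dims="time", coords={"time": [0, 1, 2]})
    out = shift_time(da, 5, units="frames")
    print(out.attrs["log"][0]["operation"])  # "shift_time"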
+ """ + # extract relevant position data + input_dataset = request.getfixturevalue(valid_dataset) + output_data_array = data_selection(input_dataset) + + # produce report + report_str = report_nan_values(output_data_array) + + # check report of nan values includes name of data array + assert output_data_array.name in report_str + + # check report of nan values includes selected individuals only + list_expected_individuals = [ + input_dataset["individuals"][idx].item() + for idx in list_expected_individuals_indices + ] + list_not_expected_individuals = [ + indiv.item() + for indiv in input_dataset["individuals"] + if indiv.item() not in list_expected_individuals + ] + assert all([ind in report_str for ind in list_expected_individuals]) + assert all( + [ind not in report_str for ind in list_not_expected_individuals] + ) + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_poses_dataset", + "valid_poses_dataset_with_nan", + ], +) +@pytest.mark.parametrize( + "data_selection, list_expected_keypoints, list_expected_individuals", + [ + ( + lambda ds: ds.position, + ["key1", "key2"], + ["ind1", "ind2"], + ), # Report nans in position for all keypoints and individuals + ( + lambda ds: ds.position.sel(keypoints="key1"), + [], + ["ind1", "ind2"], + ), # Report nans in position for keypoint "key1", for all individuals + # Note: if only one keypoint exists, it is not explicitly reported + ( + lambda ds: ds.position.sel(individuals="ind1", keypoints="key1"), + [], + ["ind1"], + ), # Report nans in position for individual "ind1" and keypoint "key1" + # Note: if only one keypoint exists, it is not explicitly reported + ], +) +def test_report_nan_values_in_position_selecting_keypoint( + valid_dataset, + data_selection, + list_expected_keypoints, + list_expected_individuals, + request, +): + """Test that the nan-value reporting function handles position data + with specific ``keypoints`` , and that the data array name (position) + and only the relevant keypoints are included in the report. + """ + # extract relevant position data + input_dataset = request.getfixturevalue(valid_dataset) + output_data_array = data_selection(input_dataset) + + # produce report + report_str = report_nan_values(output_data_array) + + # check report of nan values includes name of data array + assert output_data_array.name in report_str + + # check report of nan values includes only selected keypoints + list_not_expected_keypoints = [ + indiv.item() + for indiv in input_dataset["keypoints"] + if indiv.item() not in list_expected_keypoints + ] + assert all([kpt in report_str for kpt in list_expected_keypoints]) + assert all([kpt not in report_str for kpt in list_not_expected_keypoints]) + + # check report of nan values includes selected individuals only + list_not_expected_individuals = [ + indiv.item() + for indiv in input_dataset["individuals"] + if indiv.item() not in list_expected_individuals + ] + assert all([ind in report_str for ind in list_expected_individuals]) + assert all( + [ind not in report_str for ind in list_not_expected_individuals] + ) diff --git a/tests/test_unit/test_sample_data.py b/tests/test_unit/test_sample_data.py index 50967d62..ad408126 100644 --- a/tests/test_unit/test_sample_data.py +++ b/tests/test_unit/test_sample_data.py @@ -16,10 +16,10 @@ def valid_sample_datasets(): respective fps values, and associated frame and video file names. 
""" return { - "SLEAP_single-mouse_EPM.analysis.h5": { - "fps": 30, - "frame_file": "single-mouse_EPM_frame-20sec.png", - "video_file": "single-mouse_EPM_video.mp4", + "SLEAP_three-mice_Aeon_proofread.analysis.h5": { + "fps": 50, + "frame_file": "three-mice_Aeon_frame-5sec.png", + "video_file": "three-mice_Aeon_video.avi", }, "DLC_single-wasp.predictions.h5": { "fps": 40, @@ -31,6 +31,11 @@ def valid_sample_datasets(): "frame_file": None, "video_file": None, }, + "VIA_multiple-crabs_5-frames_labels.csv": { + "fps": None, + "frame_file": None, + "video_file": None, + }, } @@ -38,7 +43,9 @@ def validate_metadata(metadata: dict[str, dict]) -> None: """Assert that the metadata is in the expected format.""" metadata_fields = [ "sha256sum", + "type", "source_software", + "type", "fps", "species", "number_of_individuals", @@ -59,9 +66,9 @@ def validate_metadata(metadata: dict[str, dict]) -> None: ), f"Expected metadata values to be dicts. {check_yaml_msg}" assert all( set(val.keys()) == set(metadata_fields) for val in metadata.values() - ), f"Found issues with the names of medatada fields. {check_yaml_msg}" + ), f"Found issues with the names of metadata fields. {check_yaml_msg}" - # check that metadata keys (pose file names) are unique + # check that metadata keys (file names) are unique assert len(metadata.keys()) == len(set(metadata.keys())) # check that the first 2 fields are present and are strings @@ -120,18 +127,24 @@ def test_list_datasets(valid_sample_datasets): assert all(file in list_datasets() for file in valid_sample_datasets) -def test_fetch_dataset(valid_sample_datasets): +@pytest.mark.parametrize("with_video", [True, False]) +def test_fetch_dataset(valid_sample_datasets, with_video): # test with valid files for sample_name, sample in valid_sample_datasets.items(): - ds = fetch_dataset(sample_name) + ds = fetch_dataset(sample_name, with_video=with_video) assert isinstance(ds, Dataset) assert ds.attrs["fps"] == sample["fps"] if sample["frame_file"]: assert ds.attrs["frame_path"].name == sample["frame_file"] - if sample["video_file"]: + else: + assert ds.attrs["frame_path"] is None + + if sample["video_file"] and with_video: assert ds.attrs["video_path"].name == sample["video_file"] + else: + assert ds.attrs["video_path"] is None # Test with an invalid file with pytest.raises(ValueError): diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index c4b830f1..592f0c9a 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -5,7 +5,7 @@ import pandas as pd import pytest import xarray as xr -from pytest import POSE_DATA_PATHS +from pytest import DATA_PATHS from movement.io import load_poses, save_poses @@ -53,6 +53,13 @@ class TestSavePoses: }, ] + invalid_poses_datasets_and_exceptions = [ + ("not_a_dataset", TypeError), + ("empty_dataset", ValueError), + ("missing_var_poses_dataset", ValueError), + ("missing_dim_poses_dataset", ValueError), + ] + @pytest.fixture(params=output_files) def output_file_params(self, request): """Return a dictionary containing parameters for testing saving @@ -63,28 +70,28 @@ def output_file_params(self, request): @pytest.mark.parametrize( "ds, expected_exception", [ - (np.array([1, 2, 3]), pytest.raises(ValueError)), # incorrect type + (np.array([1, 2, 3]), pytest.raises(TypeError)), # incorrect type ( load_poses.from_dlc_file( - POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5") + DATA_PATHS.get("DLC_single-wasp.predictions.h5") ), does_not_raise(), ), # valid dataset ( 
load_poses.from_dlc_file( - POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv") + DATA_PATHS.get("DLC_two-mice.predictions.csv") ), does_not_raise(), ), # valid dataset ( load_poses.from_sleap_file( - POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") + DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5") ), does_not_raise(), ), # valid dataset ( load_poses.from_sleap_file( - POSE_DATA_PATHS.get( + DATA_PATHS.get( "SLEAP_three-mice_Aeon_proofread.predictions.slp" ) ), @@ -92,18 +99,18 @@ def output_file_params(self, request): ), # valid dataset ( load_poses.from_lp_file( - POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") + DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") ), does_not_raise(), ), # valid dataset ], ) - def test_to_dlc_df(self, ds, expected_exception): + def test_to_dlc_style_df(self, ds, expected_exception): """Test that converting a valid/invalid xarray dataset to a DeepLabCut-style pandas DataFrame returns the expected result. """ with expected_exception as e: - df = save_poses.to_dlc_df(ds, split_individuals=False) + df = save_poses.to_dlc_style_df(ds, split_individuals=False) if e is None: # valid input assert isinstance(df, pd.DataFrame) assert isinstance(df.columns, pd.MultiIndex) @@ -126,15 +133,19 @@ def test_to_dlc_file_valid_dataset( file_path = val.get("file_path") if isinstance(val, dict) else val save_poses.to_dlc_file(valid_poses_dataset, file_path) + @pytest.mark.parametrize( + "invalid_poses_dataset, expected_exception", + invalid_poses_datasets_and_exceptions, + ) def test_to_dlc_file_invalid_dataset( - self, invalid_poses_dataset, tmp_path + self, invalid_poses_dataset, expected_exception, tmp_path, request ): """Test that saving an invalid pose dataset to a valid DeepLabCut-style file returns the appropriate errors. """ - with pytest.raises(ValueError): + with pytest.raises(expected_exception): save_poses.to_dlc_file( - invalid_poses_dataset, + request.getfixturevalue(invalid_poses_dataset), tmp_path / "test.h5", split_individuals=False, ) @@ -163,15 +174,15 @@ def test_auto_split_individuals(self, valid_poses_dataset, split_value): ], indirect=["valid_poses_dataset"], ) - def test_to_dlc_df_split_individuals( + def test_to_dlc_style_df_split_individuals( self, valid_poses_dataset, split_individuals, ): """Test that the `split_individuals` argument affects the behaviour - of the `to_dlc_df` function as expected. + of the `to_dlc_style_df` function as expected. """ - df = save_poses.to_dlc_df(valid_poses_dataset, split_individuals) + df = save_poses.to_dlc_style_df(valid_poses_dataset, split_individuals) # Get the names of the individuals in the dataset ind_names = valid_poses_dataset.individuals.values if split_individuals is False: @@ -252,13 +263,19 @@ def test_to_lp_file_valid_dataset( file_path = val.get("file_path") if isinstance(val, dict) else val save_poses.to_lp_file(valid_poses_dataset, file_path) - def test_to_lp_file_invalid_dataset(self, invalid_poses_dataset, tmp_path): + @pytest.mark.parametrize( + "invalid_poses_dataset, expected_exception", + invalid_poses_datasets_and_exceptions, + ) + def test_to_lp_file_invalid_dataset( + self, invalid_poses_dataset, expected_exception, tmp_path, request + ): """Test that saving an invalid pose dataset to a valid LightningPose-style file returns the appropriate errors. 
""" - with pytest.raises(ValueError): + with pytest.raises(expected_exception): save_poses.to_lp_file( - invalid_poses_dataset, + request.getfixturevalue(invalid_poses_dataset), tmp_path / "test.csv", ) @@ -274,15 +291,19 @@ def test_to_sleap_analysis_file_valid_dataset( file_path = val.get("file_path") if isinstance(val, dict) else val save_poses.to_sleap_analysis_file(valid_poses_dataset, file_path) + @pytest.mark.parametrize( + "invalid_poses_dataset, expected_exception", + invalid_poses_datasets_and_exceptions, + ) def test_to_sleap_analysis_file_invalid_dataset( - self, invalid_poses_dataset, new_h5_file + self, invalid_poses_dataset, expected_exception, new_h5_file, request ): """Test that saving an invalid pose dataset to a valid SLEAP-style file returns the appropriate errors. """ - with pytest.raises(ValueError): + with pytest.raises(expected_exception): save_poses.to_sleap_analysis_file( - invalid_poses_dataset, + request.getfixturevalue(invalid_poses_dataset), new_h5_file, ) diff --git a/tests/test_unit/test_validators.py b/tests/test_unit/test_validators.py deleted file mode 100644 index fc44f94f..00000000 --- a/tests/test_unit/test_validators.py +++ /dev/null @@ -1,233 +0,0 @@ -from contextlib import nullcontext as does_not_raise - -import numpy as np -import pytest - -from movement.io.validators import ( - ValidDeepLabCutCSV, - ValidFile, - ValidHDF5, - ValidPosesDataset, -) - - -class TestValidators: - """Test suite for the validators module.""" - - position_arrays = [ - { - "names": None, - "array_type": "multi_individual_array", - "individual_names_expected_exception": does_not_raise( - ["individual_0", "individual_1"] - ), - "keypoint_names_expected_exception": does_not_raise( - ["keypoint_0", "keypoint_1"] - ), - }, # valid input, will generate default names - { - "names": ["a", "b"], - "array_type": "multi_individual_array", - "individual_names_expected_exception": does_not_raise(["a", "b"]), - "keypoint_names_expected_exception": does_not_raise(["a", "b"]), - }, # valid input - { - "names": ("a", "b"), - "array_type": "multi_individual_array", - "individual_names_expected_exception": does_not_raise(["a", "b"]), - "keypoint_names_expected_exception": does_not_raise(["a", "b"]), - }, # valid input, will be converted to ["a", "b"] - { - "names": [1, 2], - "array_type": "multi_individual_array", - "individual_names_expected_exception": does_not_raise(["1", "2"]), - "keypoint_names_expected_exception": does_not_raise(["1", "2"]), - }, # valid input, will be converted to ["1", "2"] - { - "names": "a", - "array_type": "single_individual_array", - "individual_names_expected_exception": does_not_raise(["a"]), - "keypoint_names_expected_exception": pytest.raises(ValueError), - }, # single individual array with multiple keypoints - { - "names": "a", - "array_type": "single_keypoint_array", - "individual_names_expected_exception": pytest.raises(ValueError), - "keypoint_names_expected_exception": does_not_raise(["a"]), - }, # single keypoint array with multiple individuals - { - "names": 5, - "array_type": "multi_individual_array", - "individual_names_expected_exception": pytest.raises(ValueError), - "keypoint_names_expected_exception": pytest.raises(ValueError), - }, # invalid input - ] - - @pytest.fixture(params=position_arrays) - def position_array_params(self, request): - """Return a dictionary containing parameters for testing - position array keypoint and individual names. 
- """ - return request.param - - @pytest.mark.parametrize( - "invalid_input, expected_exception", - [ - ("unreadable_file", pytest.raises(PermissionError)), - ("unwriteable_file", pytest.raises(PermissionError)), - ("fake_h5_file", pytest.raises(FileExistsError)), - ("wrong_ext_file", pytest.raises(ValueError)), - ("nonexistent_file", pytest.raises(FileNotFoundError)), - ("directory", pytest.raises(IsADirectoryError)), - ], - ) - def test_file_validator_with_invalid_input( - self, invalid_input, expected_exception, request - ): - """Test that invalid files raise the appropriate errors.""" - invalid_dict = request.getfixturevalue(invalid_input) - with expected_exception: - ValidFile( - invalid_dict.get("file_path"), - expected_permission=invalid_dict.get("expected_permission"), - expected_suffix=invalid_dict.get("expected_suffix", []), - ) - - @pytest.mark.parametrize( - "invalid_input, expected_exception", - [ - ("h5_file_no_dataframe", pytest.raises(ValueError)), - ("fake_h5_file", pytest.raises(ValueError)), - ], - ) - def test_hdf5_validator_with_invalid_input( - self, invalid_input, expected_exception, request - ): - """Test that invalid HDF5 files raise the appropriate errors.""" - invalid_dict = request.getfixturevalue(invalid_input) - with expected_exception: - ValidHDF5( - invalid_dict.get("file_path"), - expected_datasets=invalid_dict.get("expected_datasets"), - ) - - @pytest.mark.parametrize( - "invalid_input, expected_exception", - [ - ("invalid_single_individual_csv_file", pytest.raises(ValueError)), - ("invalid_multi_individual_csv_file", pytest.raises(ValueError)), - ], - ) - def test_poses_csv_validator_with_invalid_input( - self, invalid_input, expected_exception, request - ): - """Test that invalid CSV files raise the appropriate errors.""" - file_path = request.getfixturevalue(invalid_input) - with expected_exception: - ValidDeepLabCutCSV(file_path) - - @pytest.mark.parametrize( - "invalid_position_array", - [ - None, # invalid, argument is non-optional - [1, 2, 3], # not an ndarray - np.zeros((10, 2, 3)), # not 4d - np.zeros((10, 2, 3, 4)), # last dim not 2 or 3 - ], - ) - def test_poses_dataset_validator_with_invalid_position_array( - self, invalid_position_array - ): - """Test that invalid position arrays raise the appropriate errors.""" - with pytest.raises(ValueError): - ValidPosesDataset(position_array=invalid_position_array) - - @pytest.mark.parametrize( - "confidence_array, expected_exception", - [ - ( - np.ones((10, 3, 2)), - pytest.raises(ValueError), - ), # will not match position_array shape - ( - [1, 2, 3], - pytest.raises(ValueError), - ), # not an ndarray, should raise ValueError - ( - None, - does_not_raise(), - ), # valid, should default to array of NaNs - ], - ) - def test_poses_dataset_validator_confidence_array( - self, - confidence_array, - expected_exception, - valid_position_array, - ): - """Test that invalid confidence arrays raise the appropriate errors.""" - with expected_exception: - poses = ValidPosesDataset( - position_array=valid_position_array("multi_individual_array"), - confidence_array=confidence_array, - ) - if confidence_array is None: - assert np.all(np.isnan(poses.confidence_array)) - - def test_poses_dataset_validator_keypoint_names( - self, position_array_params, valid_position_array - ): - """Test that invalid keypoint names raise the appropriate errors.""" - with position_array_params.get( - "keypoint_names_expected_exception" - ) as e: - poses = ValidPosesDataset( - position_array=valid_position_array( - 
position_array_params.get("array_type") - ), - keypoint_names=position_array_params.get("names"), - ) - assert poses.keypoint_names == e - - def test_poses_dataset_validator_individual_names( - self, position_array_params, valid_position_array - ): - """Test that invalid keypoint names raise the appropriate errors.""" - with position_array_params.get( - "individual_names_expected_exception" - ) as e: - poses = ValidPosesDataset( - position_array=valid_position_array( - position_array_params.get("array_type") - ), - individual_names=position_array_params.get("names"), - ) - assert poses.individual_names == e - - @pytest.mark.parametrize( - "source_software, expected_exception", - [ - (None, does_not_raise()), - ("SLEAP", does_not_raise()), - ("DeepLabCut", does_not_raise()), - ("LightningPose", pytest.raises(ValueError)), - ("fake_software", does_not_raise()), - (5, pytest.raises(TypeError)), # not a string - ], - ) - def test_poses_dataset_validator_source_software( - self, valid_position_array, source_software, expected_exception - ): - """Test that the source_software attribute is validated properly. - LightnigPose is incompatible with multi-individual arrays. - """ - with expected_exception: - ds = ValidPosesDataset( - position_array=valid_position_array("multi_individual_array"), - source_software=source_software, - ) - - if source_software is not None: - assert ds.source_software == source_software - else: - assert ds.source_software is None diff --git a/tests/test_unit/test_validators/__init__.py b/tests/test_unit/test_validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_unit/test_validators/test_array_validators.py b/tests/test_unit/test_validators/test_array_validators.py new file mode 100644 index 00000000..a1a4412c --- /dev/null +++ b/tests/test_unit/test_validators/test_array_validators.py @@ -0,0 +1,56 @@ +import re +from contextlib import nullcontext as does_not_raise + +import pytest + +from movement.validators.arrays import validate_dims_coords + + +def expect_value_error_with_message(error_msg): + """Expect a ValueError with the specified error message.""" + return pytest.raises(ValueError, match=re.escape(error_msg)) + + +valid_cases = [ + ({"time": []}, does_not_raise()), + ({"time": [0, 1]}, does_not_raise()), + ({"space": ["x", "y"]}, does_not_raise()), + ({"time": [], "space": []}, does_not_raise()), + ({"time": [], "space": ["x", "y"]}, does_not_raise()), +] # Valid cases (no error) + +invalid_cases = [ + ( + {"spacetime": []}, + expect_value_error_with_message( + "Input data must contain ['spacetime'] as dimensions." + ), + ), + ( + {"time": [0, 100], "space": ["x", "y"]}, + expect_value_error_with_message( + "Input data must contain [100] in the 'time' coordinates." + ), + ), + ( + {"space": ["x", "y", "z"]}, + expect_value_error_with_message( + "Input data must contain ['z'] in the 'space' coordinates." 
+ ), + ), +] # Invalid cases (raise ValueError) + + +@pytest.mark.parametrize( + "required_dims_coords, expected_exception", + valid_cases + invalid_cases, +) +def test_validate_dims_coords( + valid_poses_dataset_uniform_linear_motion, # fixture from conftest.py + required_dims_coords, + expected_exception, +): + """Test validate_dims_coords for both valid and invalid inputs.""" + position_array = valid_poses_dataset_uniform_linear_motion["position"] + with expected_exception: + validate_dims_coords(position_array, required_dims_coords) diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py new file mode 100644 index 00000000..e41331f7 --- /dev/null +++ b/tests/test_unit/test_validators/test_datasets_validators.py @@ -0,0 +1,398 @@ +from contextlib import nullcontext as does_not_raise + +import numpy as np +import pytest + +from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset + +position_arrays = [ + { + "names": None, + "array_type": "multi_individual_array", + "individual_names_expected_exception": does_not_raise( + ["individual_0", "individual_1"] + ), + "keypoint_names_expected_exception": does_not_raise( + ["keypoint_0", "keypoint_1"] + ), + }, # valid input, will generate default names + { + "names": ["a", "b"], + "array_type": "multi_individual_array", + "individual_names_expected_exception": does_not_raise(["a", "b"]), + "keypoint_names_expected_exception": does_not_raise(["a", "b"]), + }, # valid input + { + "names": ("a", "b"), + "array_type": "multi_individual_array", + "individual_names_expected_exception": does_not_raise(["a", "b"]), + "keypoint_names_expected_exception": does_not_raise(["a", "b"]), + }, # valid input, will be converted to ["a", "b"] + { + "names": [1, 2], + "array_type": "multi_individual_array", + "individual_names_expected_exception": does_not_raise(["1", "2"]), + "keypoint_names_expected_exception": does_not_raise(["1", "2"]), + }, # valid input, will be converted to ["1", "2"] + { + "names": "a", + "array_type": "single_individual_array", + "individual_names_expected_exception": does_not_raise(["a"]), + "keypoint_names_expected_exception": pytest.raises(ValueError), + }, # single individual array with multiple keypoints + { + "names": "a", + "array_type": "single_keypoint_array", + "individual_names_expected_exception": pytest.raises(ValueError), + "keypoint_names_expected_exception": does_not_raise(["a"]), + }, # single keypoint array with multiple individuals + { + "names": 5, + "array_type": "multi_individual_array", + "individual_names_expected_exception": pytest.raises(ValueError), + "keypoint_names_expected_exception": pytest.raises(ValueError), + }, # invalid input +] + + +@pytest.fixture(params=position_arrays) +def position_array_params(request): + """Return a dictionary containing parameters for testing + position array keypoint and individual names. 
+    """
+    return request.param
+
+
+# Fixtures bbox dataset
+invalid_bboxes_arrays_and_expected_log = {
+    key: [
+        (
+            None,
+            f"Expected a numpy array, but got {type(None)}.",
+        ),  # invalid, argument is non-optional
+        (
+            [1, 2, 3],
+            f"Expected a numpy array, but got {type(list())}.",
+        ),  # not an ndarray
+        (
+            np.zeros((10, 2, 3)),
+            f"Expected '{key}_array' to have 2 spatial "
+            "coordinates, but got 3.",
+        ),  # last dim not 2
+    ]
+    for key in ["position", "shape"]
+}
+
+
+# Tests pose dataset
+@pytest.mark.parametrize(
+    "invalid_position_array, log_message",
+    [
+        (
+            None,
+            f"Expected a numpy array, but got {type(None)}.",
+        ),  # invalid, argument is non-optional
+        (
+            [1, 2, 3],
+            f"Expected a numpy array, but got {type(list())}.",
+        ),  # not an ndarray
+        (
+            np.zeros((10, 2, 3)),
+            "Expected 'position_array' to have 4 dimensions, but got 3.",
+        ),  # not 4d
+        (
+            np.zeros((10, 2, 3, 4)),
+            "Expected 'position_array' to have 2 or 3 "
+            "spatial dimensions, but got 4.",
+        ),  # last dim not 2 or 3
+    ],
+)
+def test_poses_dataset_validator_with_invalid_position_array(
+    invalid_position_array, log_message
+):
+    """Test that invalid position arrays raise the appropriate errors."""
+    with pytest.raises(ValueError) as excinfo:
+        ValidPosesDataset(position_array=invalid_position_array)
+    assert str(excinfo.value) == log_message
+
+
+@pytest.mark.parametrize(
+    "confidence_array, expected_exception",
+    [
+        (
+            np.ones((10, 3, 2)),
+            pytest.raises(ValueError),
+        ),  # will not match position_array shape
+        (
+            [1, 2, 3],
+            pytest.raises(ValueError),
+        ),  # not an ndarray, should raise ValueError
+        (
+            None,
+            does_not_raise(),
+        ),  # valid, should default to array of NaNs
+    ],
+)
+def test_poses_dataset_validator_confidence_array(
+    confidence_array,
+    expected_exception,
+    valid_position_array,
+):
+    """Test that invalid confidence arrays raise the appropriate errors."""
+    with expected_exception:
+        poses = ValidPosesDataset(
+            position_array=valid_position_array("multi_individual_array"),
+            confidence_array=confidence_array,
+        )
+        if confidence_array is None:
+            assert np.all(np.isnan(poses.confidence_array))
+
+
+def test_poses_dataset_validator_keypoint_names(
+    position_array_params, valid_position_array
+):
+    """Test that invalid keypoint names raise the appropriate errors."""
+    with position_array_params.get("keypoint_names_expected_exception") as e:
+        poses = ValidPosesDataset(
+            position_array=valid_position_array(
+                position_array_params.get("array_type")
+            ),
+            keypoint_names=position_array_params.get("names"),
+        )
+        assert poses.keypoint_names == e
+
+
+def test_poses_dataset_validator_individual_names(
+    position_array_params, valid_position_array
+):
+    """Test that invalid individual names raise the appropriate errors."""
+    with position_array_params.get("individual_names_expected_exception") as e:
+        poses = ValidPosesDataset(
+            position_array=valid_position_array(
+                position_array_params.get("array_type")
+            ),
+            individual_names=position_array_params.get("names"),
+        )
+        assert poses.individual_names == e
+
+
+@pytest.mark.parametrize(
+    "source_software, expected_exception",
+    [
+        (None, does_not_raise()),
+        ("SLEAP", does_not_raise()),
+        ("DeepLabCut", does_not_raise()),
+        ("LightningPose", pytest.raises(ValueError)),
+        ("fake_software", does_not_raise()),
+        (5, pytest.raises(TypeError)),  # not a string
+    ],
+)
+def test_poses_dataset_validator_source_software(
+    valid_position_array, source_software, expected_exception
+):
+    """Test that the source_software attribute is validated properly.
+    LightningPose is incompatible with multi-individual arrays.
+    """
+    with expected_exception:
+        ds = ValidPosesDataset(
+            position_array=valid_position_array("multi_individual_array"),
+            source_software=source_software,
+        )
+
+        if source_software is not None:
+            assert ds.source_software == source_software
+        else:
+            assert ds.source_software is None
+
+
+# Tests bboxes dataset
+@pytest.mark.parametrize(
+    "invalid_position_array, log_message",
+    invalid_bboxes_arrays_and_expected_log["position"],
+)
+def test_bboxes_dataset_validator_with_invalid_position_array(
+    invalid_position_array, log_message, request
+):
+    """Test that invalid centroid position arrays raise an error."""
+    with pytest.raises(ValueError) as excinfo:
+        ValidBboxesDataset(
+            position_array=invalid_position_array,
+            shape_array=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["shape"],
+            individual_names=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["individual_names"],
+        )
+    assert str(excinfo.value) == log_message
+
+
+@pytest.mark.parametrize(
+    "invalid_shape_array, log_message",
+    invalid_bboxes_arrays_and_expected_log["shape"],
+)
+def test_bboxes_dataset_validator_with_invalid_shape_array(
+    invalid_shape_array, log_message, request
+):
+    """Test that invalid shape arrays raise an error."""
+    with pytest.raises(ValueError) as excinfo:
+        ValidBboxesDataset(
+            position_array=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["position"],
+            shape_array=invalid_shape_array,
+            individual_names=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["individual_names"],
+        )
+    assert str(excinfo.value) == log_message
+
+
+@pytest.mark.parametrize(
+    "list_individual_names, expected_exception, log_message",
+    [
+        (
+            None,
+            does_not_raise(),
+            "",
+        ),  # valid, should default to unique IDs per frame
+        (
+            [1, 2, 3],
+            pytest.raises(ValueError),
+            "Expected 'individual_names' to have length 2, "
+            f"but got {len([1, 2, 3])}.",
+        ),  # length doesn't match position_array.shape[1]
+        # from valid_bboxes_arrays_all_zeros fixture
+        (
+            ["id_1", "id_1"],
+            pytest.raises(ValueError),
+            "individual_names passed to the dataset are not unique. "
+            "There are 2 elements in the list, but "
+            "only 1 are unique.",
+        ),  # some IDs are not unique.
+ # Note: length of individual_names list should match + # n_individuals in valid_bboxes_arrays_all_zeros fixture + ], +) +def test_bboxes_dataset_validator_individual_names( + list_individual_names, expected_exception, log_message, request +): + """Test individual_names inputs.""" + with expected_exception as excinfo: + ds = ValidBboxesDataset( + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], + individual_names=list_individual_names, + ) + if list_individual_names is None: + # check IDs are unique per frame + assert len(ds.individual_names) == len(set(ds.individual_names)) + assert ds.position_array.shape[1] == len(ds.individual_names) + else: + assert str(excinfo.value) == log_message + + +@pytest.mark.parametrize( + "confidence_array, expected_exception, log_message", + [ + ( + np.ones((10, 3, 2)), + pytest.raises(ValueError), + f"Expected 'confidence_array' to have shape (10, 2), " + f"but got {np.ones((10, 3, 2)).shape}.", + ), # will not match shape of position_array in + # valid_bboxes_arrays_all_zeros fixture + ( + [1, 2, 3], + pytest.raises(ValueError), + f"Expected a numpy array, but got {type([1, 2, 3])}.", + ), # not an ndarray, should raise ValueError + ( + None, + does_not_raise(), + "", + ), # valid, should default to array of NaNs + ], +) +def test_bboxes_dataset_validator_confidence_array( + confidence_array, expected_exception, log_message, request +): + """Test that invalid confidence arrays raise the appropriate errors.""" + with expected_exception as excinfo: + ds = ValidBboxesDataset( + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], + individual_names=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["individual_names"], + confidence_array=confidence_array, + ) + if confidence_array is None: + assert np.all( + np.isnan(ds.confidence_array) + ) # assert it is a NaN array + assert ( + ds.confidence_array.shape == ds.position_array.shape[:-1] + ) # assert shape matches position array + else: + assert str(excinfo.value) == log_message + + +@pytest.mark.parametrize( + "frame_array, expected_exception, log_message", + [ + ( + np.arange(10).reshape(-1, 2), + pytest.raises(ValueError), + "Expected 'frame_array' to have shape (10, 1), but got (5, 2).", + ), # frame_array should be a column vector + ( + [1, 2, 3], + pytest.raises(ValueError), + f"Expected a numpy array, but got {type(list())}.", + ), # not an ndarray, should raise ValueError + ( + np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 11]).reshape(-1, 1), + pytest.raises(ValueError), + "Frame numbers in frame_array are not continuous.", + ), # frame numbers are not continuous + ( + None, + does_not_raise(), + "", + ), # valid, should return an array of frame numbers starting from 0 + ], +) +def test_bboxes_dataset_validator_frame_array( + frame_array, expected_exception, log_message, request +): + """Test that invalid frame arrays raise the appropriate errors.""" + with expected_exception as excinfo: + ds = ValidBboxesDataset( + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], + individual_names=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["individual_names"], + frame_array=frame_array, + ) + + if frame_array is 
None: + n_frames = ds.position_array.shape[0] + default_frame_array = np.arange(n_frames).reshape(-1, 1) + assert np.array_equal(ds.frame_array, default_frame_array) + assert ds.frame_array.shape == (ds.position_array.shape[0], 1) + else: + assert str(excinfo.value) == log_message diff --git a/tests/test_unit/test_validators/test_files_validators.py b/tests/test_unit/test_validators/test_files_validators.py new file mode 100644 index 00000000..b9bc345c --- /dev/null +++ b/tests/test_unit/test_validators/test_files_validators.py @@ -0,0 +1,168 @@ +import pytest + +from movement.validators.files import ( + ValidDeepLabCutCSV, + ValidFile, + ValidHDF5, + ValidVIATracksCSV, +) + + +@pytest.mark.parametrize( + "invalid_input, expected_exception", + [ + ("unreadable_file", pytest.raises(PermissionError)), + ("unwriteable_file", pytest.raises(PermissionError)), + ("fake_h5_file", pytest.raises(FileExistsError)), + ("wrong_ext_file", pytest.raises(ValueError)), + ("nonexistent_file", pytest.raises(FileNotFoundError)), + ("directory", pytest.raises(IsADirectoryError)), + ], +) +def test_file_validator_with_invalid_input( + invalid_input, expected_exception, request +): + """Test that invalid files raise the appropriate errors.""" + invalid_dict = request.getfixturevalue(invalid_input) + with expected_exception: + ValidFile( + invalid_dict.get("file_path"), + expected_permission=invalid_dict.get("expected_permission"), + expected_suffix=invalid_dict.get("expected_suffix", []), + ) + + +@pytest.mark.parametrize( + "invalid_input, expected_exception", + [ + ("h5_file_no_dataframe", pytest.raises(ValueError)), + ("fake_h5_file", pytest.raises(ValueError)), + ], +) +def test_hdf5_validator_with_invalid_input( + invalid_input, expected_exception, request +): + """Test that invalid HDF5 files raise the appropriate errors.""" + invalid_dict = request.getfixturevalue(invalid_input) + with expected_exception: + ValidHDF5( + invalid_dict.get("file_path"), + expected_datasets=invalid_dict.get("expected_datasets"), + ) + + +@pytest.mark.parametrize( + "invalid_input, expected_exception", + [ + ("invalid_single_individual_csv_file", pytest.raises(ValueError)), + ("invalid_multi_individual_csv_file", pytest.raises(ValueError)), + ], +) +def test_deeplabcut_csv_validator_with_invalid_input( + invalid_input, expected_exception, request +): + """Test that invalid CSV files raise the appropriate errors.""" + file_path = request.getfixturevalue(invalid_input) + with expected_exception: + ValidDeepLabCutCSV(file_path) + + +@pytest.mark.parametrize( + "invalid_input, log_message", + [ + ( + "via_tracks_csv_with_invalid_header", + ".csv header row does not match the known format for " + "VIA tracks .csv files. " + "Expected " + "['filename', 'file_size', 'file_attributes', " + "'region_count', 'region_id', 'region_shape_attributes', " + "'region_attributes'] " + "but got ['filename', 'file_size', 'file_attributes'].", + ), + ( + "frame_number_in_file_attribute_not_integer", + "04.09.2023-04-Right_RE_test_frame_A.png (row 0): " + "'frame' file attribute cannot be cast as an integer. " + "Please review the file attributes: " + "{'clip': 123, 'frame': 'FOO'}.", + ), + ( + "frame_number_in_filename_wrong_pattern", + "04.09.2023-04-Right_RE_test_frame_1.png (row 0): " + "a frame number could not be extracted from the filename. " + "If included in the filename, the frame number is " + "expected as a zero-padded integer between an " + "underscore '_' and the file extension " + "(e.g. 
img_00234.png).", + ), + ( + "more_frame_numbers_than_filenames", + "The number of unique frame numbers does not match the number " + "of unique image files. Please review the VIA tracks .csv file " + "and ensure a unique frame number is defined for each file. ", + ), + ( + "less_frame_numbers_than_filenames", + "The number of unique frame numbers does not match the number " + "of unique image files. Please review the VIA tracks .csv file " + "and ensure a unique frame number is defined for each file. ", + ), + ( + "region_shape_attribute_not_rect", + "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " + "bounding box shape must be 'rect' (rectangular) " + "but instead got 'circle'.", + ), + ( + "region_shape_attribute_missing_x", + "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " + "at least one bounding box shape parameter is missing. " + "Expected 'x', 'y', 'width', 'height' to exist as " + "'region_shape_attributes', but got " + "'['name', 'y', 'width', 'height']'.", + ), + ( + "region_attribute_missing_track", + "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " + "bounding box does not have a 'track' attribute defined " + "under 'region_attributes'. " + "Please review the VIA tracks .csv file.", + ), + ( + "track_id_not_castable_as_int", + "04.09.2023-04-Right_RE_test_frame_01.png (row 0): " + "the track ID for the bounding box cannot be cast " + "as an integer. " + "Please review the VIA tracks .csv file.", + ), + ( + "track_ids_not_unique_per_frame", + "04.09.2023-04-Right_RE_test_frame_01.png: " + "multiple bounding boxes in this file have the same track ID. " + "Please review the VIA tracks .csv file.", + ), + ], +) +def test_via_tracks_csv_validator_with_invalid_input( + invalid_input, log_message, request +): + """Test that invalid VIA tracks .csv files raise the appropriate errors. 
+
+    Errors to check:
+    - error if .csv header is wrong
+    - error if frame number is not defined in the file
+      (frame number extracted either from the filename or from attributes)
+    - error if extracted frame numbers are not 1-based integers
+    - error if region_shape_attributes "name" is not "rect"
+    - error if not all region_attributes have key "track"
+      (i.e., all regions must have an ID assigned)
+    - error if IDs are not unique per frame
+      (i.e., bboxes IDs must exist only once per frame)
+    - error if bboxes IDs are not 1-based integers
+    """
+    file_path = request.getfixturevalue(invalid_input)
+    with pytest.raises(ValueError) as excinfo:
+        ValidVIATracksCSV(file_path)
+
+    assert str(excinfo.value) == log_message
diff --git a/tests/test_unit/test_vector.py b/tests/test_unit/test_vector.py
index 88dd85b5..8787a468 100644
--- a/tests/test_unit/test_vector.py
+++ b/tests/test_unit/test_vector.py
@@ -121,3 +121,70 @@ def test_pol2cart(self, ds, expected_exception, request):
         with expected_exception:
             result = vector.pol2cart(ds.pol)
             xr.testing.assert_allclose(result, ds.cart)
+
+    @pytest.mark.parametrize(
+        "ds, expected_exception",
+        [
+            ("cart_pol_dataset", does_not_raise()),
+            ("cart_pol_dataset_with_nan", does_not_raise()),
+            ("cart_pol_dataset_missing_cart_dim", pytest.raises(ValueError)),
+            (
+                "cart_pol_dataset_missing_cart_coords",
+                pytest.raises(ValueError),
+            ),
+        ],
+    )
+    def test_compute_norm(self, ds, expected_exception, request):
+        """Test vector norm computation with known values."""
+        ds = request.getfixturevalue(ds)
+        with expected_exception:
+            # validate the norm computation
+            result = vector.compute_norm(ds.cart)
+            expected = np.sqrt(
+                ds.cart.sel(space="x") ** 2 + ds.cart.sel(space="y") ** 2
+            )
+            xr.testing.assert_allclose(result, expected)
+
+            # result should be the same from Cartesian and polar coordinates
+            xr.testing.assert_allclose(result, vector.compute_norm(ds.pol))
+
+            # The result should only contain the time dimension.
+ assert result.dims == ("time",) + + @pytest.mark.parametrize( + "ds, expected_exception", + [ + ("cart_pol_dataset", does_not_raise()), + ("cart_pol_dataset_with_nan", does_not_raise()), + ("cart_pol_dataset_missing_cart_dim", pytest.raises(ValueError)), + ], + ) + def test_convert_to_unit(self, ds, expected_exception, request): + """Test conversion to unit vectors (normalisation).""" + ds = request.getfixturevalue(ds) + with expected_exception: + # normalise both the Cartesian and the polar data to unit vectors + unit_cart = vector.convert_to_unit(ds.cart) + unit_pol = vector.convert_to_unit(ds.pol) + # they should yield the same result, just in different coordinates + xr.testing.assert_allclose(unit_cart, vector.pol2cart(unit_pol)) + xr.testing.assert_allclose(unit_pol, vector.cart2pol(unit_cart)) + + # since we established that polar vs Cartesian unit vectors are + # equivalent, it's enough to do other assertions on either one + + # the normalised data should have the same dimensions as the input + assert unit_cart.dims == ds.cart.dims + + # unit vector should be NaN if the input vector was null or NaN + is_null_vec = (ds.cart == 0).all("space") # null vec: x=0, y=0 + is_nan_vec = ds.cart.isnull().any("space") # any NaN in x or y + expected_nan_idxs = is_null_vec | is_nan_vec + assert unit_cart.where(expected_nan_idxs).isnull().all() + + # For non-NaN unit vectors in polar coordinates, the rho values + # should be 1 and the phi values should be the same as the input + expected_unit_pol = ds.pol.copy() + expected_unit_pol.loc[{"space_pol": "rho"}] = 1 + expected_unit_pol = expected_unit_pol.where(~expected_nan_idxs) + xr.testing.assert_allclose(unit_pol, expected_unit_pol)
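The new `test_compute_norm` and `test_convert_to_unit` cases above pin down how these helpers behave on a plain time-by-space `DataArray`: the norm is `sqrt(x**2 + y**2)` reduced over the `space` dimension, and null or NaN input vectors normalise to NaN. A minimal usage sketch of that behaviour follows; it assumes the helpers live in `movement.utils.vector` (the module used as `vector` in `tests/test_unit/test_vector.py`), and the toy array and printed values are illustrative only, not taken from the test fixtures.

```python
import numpy as np
import xarray as xr

from movement.utils import vector  # assumed import path, as used in the tests

# Toy Cartesian data: 4 time points, 2 spatial coordinates (x, y).
# The last time point is a null vector, so its unit vector should be NaN.
cart = xr.DataArray(
    np.array([[3.0, 4.0], [0.0, 1.0], [1.0, 0.0], [0.0, 0.0]]),
    dims=["time", "space"],
    coords={"time": np.arange(4), "space": ["x", "y"]},
)

norm = vector.compute_norm(cart)  # sqrt(x**2 + y**2), dims: ("time",)
unit = vector.convert_to_unit(cart)  # unit vectors, same dims as `cart`

print(norm.values)  # expected: [5. 1. 1. 0.]
print(unit.values[-1])  # expected: [nan nan] for the null vector
```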