From f7f3b48265f80b9a62059a5cbb55572b773e2ee7 Mon Sep 17 00:00:00 2001 From: Niko Sirmpilatze Date: Thu, 24 Oct 2024 11:50:51 +0100 Subject: [PATCH] Drop the `.move` accessor (#322) * replace accessor with MovementDataset dataclass * moved pre-save validations inside save_poses module * deleted accessor code and associated tests * define dataset structure in modular classes * updated stale docstring for _validate_dataset() * remove mentions of the accessor from the getting started guide * dropped accessor use in examples * ignore linkcheck for opensource licenses * Revert "ignore linkcheck for opensource licenses" This reverts commit c8f3498f2c911ac79c08cc909746f2a7335a4699. * use ds.sizes instead of ds.dims to suppress warning * Add references * remove movement_dataset.py module --------- Co-authored-by: lochhh --- .../getting_started/movement_dataset.md | 58 ++-- examples/compute_kinematics.py | 45 ++- examples/filter_and_interpolate.py | 102 ++---- examples/smooth.py | 82 ++--- movement/__init__.py | 1 - movement/io/load_bboxes.py | 3 +- movement/io/load_poses.py | 3 +- movement/io/save_poses.py | 20 +- movement/move_accessor.py | 302 ------------------ movement/validators/datasets.py | 9 +- tests/conftest.py | 8 +- tests/test_integration/test_filtering.py | 76 ++--- .../test_kinematics_vector_transform.py | 5 +- tests/test_unit/test_load_bboxes.py | 4 +- tests/test_unit/test_load_poses.py | 4 +- tests/test_unit/test_move_accessor.py | 128 -------- tests/test_unit/test_save_poses.py | 6 +- .../test_datasets_validators.py | 2 +- 18 files changed, 168 insertions(+), 690 deletions(-) delete mode 100644 movement/move_accessor.py delete mode 100644 tests/test_unit/test_move_accessor.py diff --git a/docs/source/getting_started/movement_dataset.md b/docs/source/getting_started/movement_dataset.md index 6b2ef5e9..7e80a81c 100644 --- a/docs/source/getting_started/movement_dataset.md +++ b/docs/source/getting_started/movement_dataset.md @@ -195,7 +195,7 @@ For example, you can: [data aggregation and broadcasting](xarray:user-guide/computation.html), and - use `xarray`'s built-in [plotting methods](xarray:user-guide/plotting.html). -As an example, here's how you can use the `sel` method to select subsets of +As an example, here's how you can use {meth}`xarray.Dataset.sel` to select subsets of data: ```python @@ -223,44 +223,41 @@ position = ds.position.sel( ) # the output is a data array ``` -### Accessing movement-specific functionality +### Modifying movement datasets -`movement` extends `xarray`'s functionality with a number of convenience -methods that are specific to `movement` datasets. These `movement`-specific methods are accessed using the -`move` keyword. +Datasets can be modified by adding new **data variables** and **attributes**, +or updating existing ones. -For example, to compute the velocity and acceleration vectors for all individuals and keypoints across time, we provide the `move.compute_velocity` and `move.compute_acceleration` methods: +Let's imagine we want to compute the instantaneous velocity of all tracked +points and store the results within the same dataset, for convenience. ```python -velocity = ds.move.compute_velocity() -acceleration = ds.move.compute_acceleration() -``` +from movement.analysis.kinematics import compute_velocity -The `movement`-specific functionalities are implemented in the -{class}`movement.move_accessor.MovementDataset` class, which is an [accessor](https://docs.xarray.dev/en/stable/internals/extending-xarray.html) to the -underlying {class}`xarray.Dataset` object. Defining a custom accessor is convenient -to avoid conflicts with `xarray`'s built-in methods. +# compute velocity from position +velocity = compute_velocity(ds.position) +# add it to the dataset as a new data variable +ds["velocity"] = velocity -### Modifying movement datasets +# we could have also done both steps in a single line +ds["velocity"] = compute_velocity(ds.position) -The `velocity` and `acceleration` produced in the above example are {class}`xarray.DataArray` objects, with the same **dimensions** as the -original `position` **data variable**. +# we can now access velocity like any other data variable +ds.velocity +``` -In some cases, you may wish to -add these or other new **data variables** to the `movement` dataset for -convenience. This can be done by simply assigning them to the dataset -with an appropriate name: +The output of {func}`movement.analysis.kinematics.compute_velocity` is an {class}`xarray.DataArray` object, +with the same **dimensions** as the original `position` **data variable**, +so adding it to the existing `ds` makes sense and works seamlessly. -```python -ds["velocity"] = velocity -ds["acceleration"] = acceleration +We can also update existing **data variables** in-place, using {meth}`xarray.Dataset.update`. For example, if we wanted to update the `position` +and `velocity` arrays in our dataset, we could do: -# we can now access these using dot notation on the dataset -ds.velocity -ds.acceleration +```python +ds.update({"position": position_filtered, "velocity": velocity_filtered}) ``` -Custom **attributes** can also be added to the dataset: +Custom **attributes** can be added to the dataset with: ```python ds.attrs["my_custom_attribute"] = "my_custom_value" @@ -268,10 +265,3 @@ ds.attrs["my_custom_attribute"] = "my_custom_value" # we can now access this value using dot notation on the dataset ds.my_custom_attribute ``` - -We can also update existing **data variables** in-place, using the `update()` method. For example, if we wanted to update the `position` -and `velocity` arrays in our dataset, we could do: - -```python -ds.update({"position": position_filtered, "velocity": velocity_filtered}) -``` diff --git a/examples/compute_kinematics.py b/examples/compute_kinematics.py index b3fefd4a..d9107696 100644 --- a/examples/compute_kinematics.py +++ b/examples/compute_kinematics.py @@ -117,27 +117,22 @@ # %% # Compute displacement # --------------------- +# The :mod:`movement.analysis.kinematics` module provides functions to compute +# various kinematic quantities, +# such as displacement, velocity, and acceleration. # We can start off by computing the distance travelled by the mice along -# their trajectories. -# For this, we can use the ``compute_displacement`` method of the -# ``move`` accessor. -displacement = ds.move.compute_displacement() +# their trajectories: -# %% -# This method will return a data array equivalent to the ``position`` one, -# but holding displacement data along the ``space`` axis, rather than -# position data. - -# %% -# Notice that we could also compute the displacement (and all the other -# kinematic variables) using the :mod:`movement.analysis.kinematics` module: - -# %% import movement.analysis.kinematics as kin -displacement_kin = kin.compute_displacement(position) +displacement = kin.compute_displacement(position) # %% +# The :func:`movement.analysis.kinematics.compute_displacement` +# function will return a data array equivalent to the ``position`` one, +# but holding displacement data along the ``space`` axis, rather than +# position data. +# # The ``displacement`` data array holds, for a given individual and keypoint # at timestep ``t``, the vector that goes from its previous position at time # ``t-1`` to its current position at time ``t``. @@ -271,13 +266,14 @@ # ---------------- # We can easily compute the velocity vectors for all individuals in our data # array: -velocity = ds.move.compute_velocity() +velocity = kin.compute_velocity(position) # %% -# The ``velocity`` method will return a data array equivalent to the -# ``position`` one, but holding velocity data along the ``space`` axis, rather -# than position data. Notice how ``xarray`` nicely deals with the different -# individuals and spatial dimensions for us! ✨ +# The :func:`movement.analysis.kinematics.compute_velocity` +# function will return a data array equivalent to +# the ``position`` one, but holding velocity data along the ``space`` axis, +# rather than position data. Notice how ``xarray`` nicely deals with the +# different individuals and spatial dimensions for us! ✨ # %% # We can plot the components of the velocity vector against time @@ -350,8 +346,9 @@ # %% # Compute acceleration # --------------------- -# We can compute the acceleration of the data with an equivalent method: -accel = ds.move.compute_acceleration() +# Let's now compute the acceleration for all individuals in our data +# array: +accel = kin.compute_acceleration(position) # %% # and plot of the components of the acceleration vector ``ax``, ``ay`` per @@ -375,8 +372,8 @@ fig.tight_layout() # %% -# The can also represent the magnitude (norm) of the acceleration vector -# for each individual: +# We can also compute and visualise the magnitude (norm) of the +# acceleration vector for each individual: fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) for mouse_name, ax in zip(accel.individuals.values, axes, strict=False): # compute magnitude of the acceleration vector for one mouse diff --git a/examples/filter_and_interpolate.py b/examples/filter_and_interpolate.py index 71384ca7..fa18582b 100644 --- a/examples/filter_and_interpolate.py +++ b/examples/filter_and_interpolate.py @@ -9,6 +9,8 @@ # Imports # ------- from movement import sample_data +from movement.analysis.kinematics import compute_velocity +from movement.filtering import filter_by_confidence, interpolate_over_time # %% # Load a sample dataset @@ -73,35 +75,19 @@ # %% # Filter out points with low confidence # ------------------------------------- -# Using the -# :meth:`filter_by_confidence()\ -# ` -# method of the ``move`` accessor, -# we can filter out points with confidence scores below a certain threshold. -# The default ``threshold=0.6`` will be used when ``threshold`` is not -# provided. -# This method will also report the number of NaN values in the dataset before -# and after the filtering operation by default (``print_report=True``). +# Using the :func:`movement.filtering.filter_by_confidence` function from the +# :mod:`movement.filtering` module, we can filter out points with confidence +# scores below a certain threshold. This function takes ``position`` and +# ``confidence`` as required arguments, and accepts an optional ``threshold`` +# parameter, which defaults to ``threshold=0.6`` unless specified otherwise. +# The function will also report the number of NaN values in the dataset before +# and after the filtering operation by default, but you can disable this +# by passing ``print_report=False``. +# # We will use :meth:`xarray.Dataset.update` to update ``ds`` in-place # with the filtered ``position``. -ds.update({"position": ds.move.filter_by_confidence()}) - -# %% -# .. note:: -# The ``move`` accessor :meth:`filter_by_confidence()\ -# ` -# method is a convenience method that applies -# :func:`movement.filtering.filter_by_confidence`, -# which takes ``position`` and ``confidence`` as arguments. -# The equivalent function call using the -# :mod:`movement.filtering` module would be: -# -# .. code-block:: python -# -# from movement.filtering import filter_by_confidence -# -# ds.update({"position": filter_by_confidence(position, confidence)}) +ds.update({"position": filter_by_confidence(ds.position, ds.confidence)}) # %% # We can see that the filtering operation has introduced NaN values in the @@ -120,36 +106,16 @@ # %% # Interpolate over missing values # ------------------------------- -# Using the -# :meth:`interpolate_over_time()\ -# ` -# method of the ``move`` accessor, -# we can interpolate over the gaps we've introduced in the pose tracks. +# Using the :func:`movement.filtering.interpolate_over_time` function from the +# :mod:`movement.filtering` module, we can interpolate over gaps +# we've introduced in the pose tracks. # Here we use the default linear interpolation method (``method=linear``) # and interpolate over gaps of 40 frames or less (``max_gap=40``). # The default ``max_gap=None`` would interpolate over all gaps, regardless of # their length, but this should be used with caution as it can introduce # spurious data. The ``print_report`` argument acts as described above. -ds.update({"position": ds.move.interpolate_over_time(max_gap=40)}) - -# %% -# .. note:: -# The ``move`` accessor :meth:`interpolate_over_time()\ -# ` -# is also a convenience method that applies -# :func:`movement.filtering.interpolate_over_time` -# to the ``position`` data variable. -# The equivalent function call using the -# :mod:`movement.filtering` module would be: -# -# .. code-block:: python -# -# from movement.filtering import interpolate_over_time -# -# ds.update({"position": interpolate_over_time( -# position_filtered, max_gap=40 -# )}) +ds.update({"position": interpolate_over_time(ds.position, max_gap=40)}) # %% # We see that all NaN values have disappeared, meaning that all gaps were @@ -176,27 +142,25 @@ # %% # Filtering multiple data variables # --------------------------------- -# All :mod:`movement.filtering` functions are available via the -# ``move`` accessor. These ``move`` accessor methods operate on the -# ``position`` data variable in the dataset ``ds`` by default. -# There is also an additional argument ``data_vars`` that allows us to -# specify which data variables in ``ds`` to filter. -# When multiple data variable names are specified in ``data_vars``, -# the method will return a dictionary with the data variable names as keys -# and the filtered DataArrays as values, otherwise it will return a single -# DataArray that is the filtered data. -# This is useful when we want to apply the same filtering operation to +# We can also apply the same filtering operation to # multiple data variables in ``ds`` at the same time. # # For instance, to filter both ``position`` and ``velocity`` data variables -# in ``ds``, based on the confidence scores, we can specify -# ``data_vars=["position", "velocity"]`` in the method call. -# As the filtered data variables are returned as a dictionary, we can once -# again use :meth:`xarray.Dataset.update` to update ``ds`` in-place +# in ``ds``, based on the confidence scores, we can specify a dictionary +# with the data variable names as keys and the corresponding filtered +# DataArrays as values. Then we can once again use +# :meth:`xarray.Dataset.update` to update ``ds`` in-place # with the filtered data variables. -ds["velocity"] = ds.move.compute_velocity() -filtered_data_dict = ds.move.filter_by_confidence( - data_vars=["position", "velocity"] -) -ds.update(filtered_data_dict) +# Add velocity data variable to the dataset +ds["velocity"] = compute_velocity(ds.position) + +# Create a dictionary mapping data variable names to filtered DataArrays +# We disable report printing for brevity +update_dict = { + var: filter_by_confidence(ds[var], ds.confidence, print_report=False) + for var in ["position", "velocity"] +} + +# Use the dictionary to update the dataset in-place +ds.update(update_dict) diff --git a/examples/smooth.py b/examples/smooth.py index 316d9444..f87ac411 100644 --- a/examples/smooth.py +++ b/examples/smooth.py @@ -12,6 +12,11 @@ from scipy.signal import welch from movement import sample_data +from movement.filtering import ( + interpolate_over_time, + median_filter, + savgol_filter, +) # %% # Load a sample dataset @@ -33,8 +38,8 @@ # %% # Define a plotting function # -------------------------- -# Let's define a plotting function to help us visualise the effects smoothing -# both in the time and frequency domains. +# Let's define a plotting function to help us visualise the effects of +# smoothing both in the time and frequency domains. # The function takes as inputs two datasets containing raw and smooth data # respectively, and plots the position time series and power spectral density # (PSD) for a given individual and keypoint. The function also allows you to @@ -77,9 +82,8 @@ def plot_raw_and_smooth_timeseries_and_psd( ) # interpolate data to remove NaNs in the PSD calculation - pos_interp = ds.sel(**selection).move.interpolate_over_time( - print_report=False - ) + pos_interp = interpolate_over_time(pos, print_report=False) + # compute and plot the PSD freq, psd = welch(pos_interp, fs=ds.fps, nperseg=256) ax[1].semilogy( @@ -108,12 +112,9 @@ def plot_raw_and_smooth_timeseries_and_psd( # %% # Smoothing with a median filter # ------------------------------ -# Using the -# :meth:`median_filter()\ -# ` -# method of the ``move`` accessor, -# we apply a rolling window median filter over a 0.1-second window -# (4 frames) to the wasp dataset. +# Using the :func:`movement.filtering.median_filter` function on the +# ``position`` data variable, we can apply a rolling window median filter +# over a 0.1-second window (4 frames) to the wasp dataset. # As the ``window`` parameter is defined in *number of observations*, # we can simply multiply the desired time window by the frame rate # of the video. We will also create a copy of the dataset to avoid @@ -121,23 +122,7 @@ def plot_raw_and_smooth_timeseries_and_psd( window = int(0.1 * ds_wasp.fps) ds_wasp_smooth = ds_wasp.copy() -ds_wasp_smooth.update({"position": ds_wasp_smooth.move.median_filter(window)}) - -# %% -# .. note:: -# The ``move`` accessor :meth:`median_filter()\ -# ` -# method is a convenience method that applies -# :func:`movement.filtering.median_filter` -# to the ``position`` data variable. -# The equivalent function call using the -# :mod:`movement.filtering` module would be: -# -# .. code-block:: python -# -# from movement.filtering import median_filter -# -# ds_wasp_smooth.update({"position": median_filter(position, window)}) +ds_wasp_smooth.update({"position": median_filter(ds_wasp.position, window)}) # %% # We see from the printed report that the dataset has no missing values @@ -181,9 +166,7 @@ def plot_raw_and_smooth_timeseries_and_psd( window = int(0.1 * ds_mouse.fps) ds_mouse_smooth = ds_mouse.copy() -ds_mouse_smooth.update( - {"position": ds_mouse_smooth.move.median_filter(window)} -) +ds_mouse_smooth.update({"position": median_filter(ds_mouse.position, window)}) # %% # The report informs us that the raw data contains NaN values, most of which @@ -199,7 +182,7 @@ def plot_raw_and_smooth_timeseries_and_psd( # window are sufficient for the median to be calculated. Let's try this. ds_mouse_smooth.update( - {"position": ds_mouse.move.median_filter(window, min_periods=2)} + {"position": median_filter(ds_mouse.position, window, min_periods=2)} ) # %% @@ -222,7 +205,7 @@ def plot_raw_and_smooth_timeseries_and_psd( window = int(2 * ds_mouse.fps) ds_mouse_smooth.update( - {"position": ds_mouse.move.median_filter(window, min_periods=2)} + {"position": median_filter(ds_mouse.position, window, min_periods=2)} ) # %% @@ -248,13 +231,9 @@ def plot_raw_and_smooth_timeseries_and_psd( # %% # Smoothing with a Savitzky-Golay filter # -------------------------------------- -# Here we use the -# :meth:`savgol_filter()\ -# ` -# method of the ``move`` accessor, which is a convenience method that applies -# :func:`movement.filtering.savgol_filter` -# (a wrapper around :func:`scipy.signal.savgol_filter`), -# to the ``position`` data variable. +# Here we apply the :func:`movement.filtering.savgol_filter` function +# (a wrapper around :func:`scipy.signal.savgol_filter`), to the ``position`` +# data variable. # The Savitzky-Golay filter is a polynomial smoothing filter that can be # applied to time series data on a rolling window basis. # A polynomial with a degree specified by ``polyorder`` is applied to each @@ -268,7 +247,7 @@ def plot_raw_and_smooth_timeseries_and_psd( # to be used as the ``window`` size. window = int(0.2 * ds_mouse.fps) -ds_mouse_smooth.update({"position": ds_mouse.move.savgol_filter(window)}) +ds_mouse_smooth.update({"position": savgol_filter(ds_mouse.position, window)}) # %% # We see that the number of NaN values has increased after filtering. This is @@ -289,7 +268,7 @@ def plot_raw_and_smooth_timeseries_and_psd( # Now let's apply the same Savitzky-Golay filter to the wasp dataset. window = int(0.2 * ds_wasp.fps) -ds_wasp_smooth.update({"position": ds_wasp.move.savgol_filter(window)}) +ds_wasp_smooth.update({"position": savgol_filter(ds_wasp.position, window)}) # %% plot_raw_and_smooth_timeseries_and_psd( @@ -315,27 +294,24 @@ def plot_raw_and_smooth_timeseries_and_psd( # with a larger ``window`` to further smooth the data. # Between the two filters, we can interpolate over small gaps to avoid the # excessive proliferation of NaN values. Let's try this on the mouse dataset. -# First, we will apply the median filter. +# First, we will apply the median filter. window = int(0.1 * ds_mouse.fps) ds_mouse_smooth.update( - {"position": ds_mouse.move.median_filter(window, min_periods=2)} + {"position": median_filter(ds_mouse.position, window, min_periods=2)} ) -# %% -# Next, let's linearly interpolate over gaps smaller than 1 second (30 frames). - +# Next, let's linearly interpolate over gaps smaller +# than 1 second (30 frames). ds_mouse_smooth.update( - {"position": ds_mouse_smooth.move.interpolate_over_time(max_gap=30)} + {"position": interpolate_over_time(ds_mouse_smooth.position, max_gap=30)} ) -# %% -# Finally, let's apply the Savitzky-Golay filter over a 0.4-second window -# (12 frames). - +# Finally, let's apply the Savitzky-Golay filter +# over a 0.4-second window (12 frames). window = int(0.4 * ds_mouse.fps) ds_mouse_smooth.update( - {"position": ds_mouse_smooth.move.savgol_filter(window)} + {"position": savgol_filter(ds_mouse_smooth.position, window)} ) # %% diff --git a/movement/__init__.py b/movement/__init__.py index bc9115b1..bf5d4a2d 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -1,7 +1,6 @@ from importlib.metadata import PackageNotFoundError, version from movement.utils.logging import configure_logging -from movement.move_accessor import MovementDataset try: __version__ = version("movement") diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 8550a2e8..3e1b0e0d 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -11,7 +11,6 @@ import pandas as pd import xarray as xr -from movement import MovementDataset from movement.utils.logging import log_error from movement.validators.datasets import ValidBboxesDataset from movement.validators.files import ValidFile, ValidVIATracksCSV @@ -631,7 +630,7 @@ def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset: # Convert data to an xarray.Dataset # with dimensions ('time', 'individuals', 'space') - DIM_NAMES = MovementDataset.dim_names["bboxes"] + DIM_NAMES = ValidBboxesDataset.DIM_NAMES n_space = data.position_array.shape[-1] return xr.Dataset( data_vars={ diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 2b1a25d8..f425d8a1 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -11,7 +11,6 @@ from sleap_io.io.slp import read_labels from sleap_io.model.labels import Labels -from movement import MovementDataset from movement.utils.logging import log_error, log_warning from movement.validators.datasets import ValidPosesDataset from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5 @@ -654,7 +653,7 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: time_coords = time_coords / data.fps time_unit = "seconds" - DIM_NAMES = MovementDataset.dim_names["poses"] + DIM_NAMES = ValidPosesDataset.DIM_NAMES # Convert data to an xarray.Dataset return xr.Dataset( data_vars={ diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index bc2c0e1c..c47d28f1 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -10,6 +10,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.validators.datasets import ValidPosesDataset from movement.validators.files import ValidFile logger = logging.getLogger(__name__) @@ -424,12 +425,25 @@ def _validate_dataset(ds: xr.Dataset) -> None: Raises ------ + TypeError + If the input is not an xarray Dataset. ValueError - If `ds` is not an a valid ``movement`` dataset. + If the dataset is missing required data variables or dimensions. """ if not isinstance(ds, xr.Dataset): raise log_error( - ValueError, f"Expected an xarray Dataset, but got {type(ds)}." + TypeError, f"Expected an xarray Dataset, but got {type(ds)}." ) - ds.move.validate() # validate the dataset + + missing_vars = set(ValidPosesDataset.VAR_NAMES) - set(ds.data_vars) + if missing_vars: + raise ValueError( + f"Missing required data variables: {sorted(missing_vars)}" + ) # sort for a reproducible error message + + missing_dims = set(ValidPosesDataset.DIM_NAMES) - set(ds.dims) + if missing_dims: + raise ValueError( + f"Missing required dimensions: {sorted(missing_dims)}" + ) # sort for a reproducible error message diff --git a/movement/move_accessor.py b/movement/move_accessor.py deleted file mode 100644 index 64b17651..00000000 --- a/movement/move_accessor.py +++ /dev/null @@ -1,302 +0,0 @@ -"""Accessor for extending :class:`xarray.Dataset` objects.""" - -import logging -from typing import ClassVar - -import xarray as xr - -from movement import filtering -from movement.analysis import kinematics -from movement.utils.logging import log_error -from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset - -logger = logging.getLogger(__name__) - -# Preserve the attributes (metadata) of xarray objects after operations -xr.set_options(keep_attrs=True) - - -@xr.register_dataset_accessor("move") -class MovementDataset: - """An :class:`xarray.Dataset` accessor for ``movement`` data. - - A ``movement`` dataset is an :class:`xarray.Dataset` with a specific - structure to represent pose tracks or bounding boxes data, - associated confidence scores and relevant metadata. - - Methods/properties that extend the standard ``xarray`` functionality are - defined in this class. To avoid conflicts with ``xarray``'s namespace, - ``movement``-specific methods are accessed using the ``move`` keyword, - for example ``ds.move.validate()`` (see [1]_ for more details). - - Attributes - ---------- - dim_names : dict - A dictionary with the names of the expected dimensions in the dataset, - for each dataset type (``"poses"`` or ``"bboxes"``). - var_names : dict - A dictionary with the expected data variables in the dataset, for each - dataset type (``"poses"`` or ``"bboxes"``). - - References - ---------- - .. [1] https://docs.xarray.dev/en/stable/internals/extending-xarray.html - - """ - - # Set class attributes for expected dimensions and data variables - dim_names: ClassVar[dict] = { - "poses": ("time", "individuals", "keypoints", "space"), - "bboxes": ("time", "individuals", "space"), - } - var_names: ClassVar[dict] = { - "poses": ("position", "confidence"), - "bboxes": ("position", "shape", "confidence"), - } - - def __init__(self, ds: xr.Dataset): - """Initialize the MovementDataset.""" - self._obj = ds - # Set instance attributes based on dataset type - self.dim_names_instance = self.dim_names[self._obj.ds_type] - self.var_names_instance = self.var_names[self._obj.ds_type] - - def __getattr__(self, name: str) -> xr.DataArray: - """Forward requested but undefined attributes to relevant modules. - - This method currently only forwards kinematic property computation - and filtering operations to the respective functions in - :mod:`movement.analysis.kinematics` and - :mod:`movement.filtering`. - - Parameters - ---------- - name : str - The name of the attribute to get. - - Returns - ------- - xarray.DataArray - The computed attribute value. - - Raises - ------ - AttributeError - If the attribute does not exist. - - """ - - def method(*args, **kwargs): - if hasattr(kinematics, name): - return self.kinematics_wrapper(name, *args, **kwargs) - elif hasattr(filtering, name): - return self.filtering_wrapper(name, *args, **kwargs) - else: - error_msg = ( - f"'{self.__class__.__name__}' object has " - f"no attribute '{name}'" - ) - raise log_error(AttributeError, error_msg) - - return method - - def kinematics_wrapper( - self, fn_name: str, *args, **kwargs - ) -> xr.DataArray: - """Provide convenience method for computing kinematic properties. - - This method forwards kinematic property computation - to the respective functions in :mod:`movement.analysis.kinematics`. - - Parameters - ---------- - fn_name : str - The name of the kinematics function to call. - args : tuple - Positional arguments to pass to the function. - kwargs : dict - Keyword arguments to pass to the function. - - Returns - ------- - xarray.DataArray - The computed kinematics attribute value. - - Raises - ------ - RuntimeError - If the requested function fails to execute. - - Examples - -------- - Compute ``displacement`` based on the ``position`` data variable - in the Dataset ``ds`` and store the result in ``ds``. - - >>> ds["displacement"] = ds.move.compute_displacement() - - Compute ``velocity`` based on the ``position`` data variable in - the Dataset ``ds`` and store the result in ``ds``. - - >>> ds["velocity"] = ds.move.compute_velocity() - - Compute ``acceleration`` based on the ``position`` data variable - in the Dataset ``ds`` and store the result in ``ds``. - - >>> ds["acceleration"] = ds.move.compute_acceleration() - - """ - try: - return getattr(kinematics, fn_name)( - self._obj.position, *args, **kwargs - ) - except Exception as e: - error_msg = ( - f"Failed to evoke '{fn_name}' via 'move' accessor. {str(e)}" - ) - raise log_error(RuntimeError, error_msg) from e - - def filtering_wrapper( - self, fn_name: str, *args, data_vars: list[str] | None = None, **kwargs - ) -> xr.DataArray | dict[str, xr.DataArray]: - """Provide convenience method for filtering data variables. - - This method forwards filtering and/or smoothing to the respective - functions in :mod:`movement.filtering`. The data variables to - filter can be specified in ``data_vars``. If ``data_vars`` is not - specified, the ``position`` data variable is selected by default. - - Parameters - ---------- - fn_name : str - The name of the filtering function to call. - args : tuple - Positional arguments to pass to the function. - data_vars : list[str] | None - The data variables to apply filtering. If ``None``, the - ``position`` data variable will be passed by default. - kwargs : dict - Keyword arguments to pass to the function. - - Returns - ------- - xarray.DataArray | dict[str, xarray.DataArray] - The filtered data variable or a dictionary of filtered data - variables, if multiple data variables are specified. - - Raises - ------ - RuntimeError - If the requested function fails to execute. - - Examples - -------- - Filter the ``position`` data variable to drop points with - ``confidence`` below 0.7 and store the result back into the - Dataset ``ds``. - Since ``data_vars`` is not supplied, the filter will be applied to - the ``position`` data variable by default. - - >>> ds["position"] = ds.move.filter_by_confidence(threshold=0.7) - - Apply a median filter to the ``position`` data variable and - store this back into the Dataset ``ds``. - - >>> ds["position"] = ds.move.median_filter(window=3) - - Apply a Savitzky-Golay filter to both the ``position`` and - ``velocity`` data variables and store these back into the - Dataset ``ds``. ``filtered_data`` is a dictionary, where the keys - are the data variable names and the values are the filtered - DataArrays. - - >>> filtered_data = ds.move.savgol_filter( - ... window=3, data_vars=["position", "velocity"] - ... ) - >>> ds.update(filtered_data) - - """ - ds = self._obj - if data_vars is None: # Default to filter on position - data_vars = ["position"] - if fn_name == "filter_by_confidence": - # Add confidence to kwargs - kwargs["confidence"] = ds.confidence - try: - result = { - data_var: getattr(filtering, fn_name)( - ds[data_var], *args, **kwargs - ) - for data_var in data_vars - } - # Return DataArray if result only has one key - if len(result) == 1: - return result[list(result.keys())[0]] - return result - except Exception as e: - error_msg = ( - f"Failed to evoke '{fn_name}' via 'move' accessor. {str(e)}" - ) - raise log_error(RuntimeError, error_msg) from e - - def validate(self) -> None: - """Validate the dataset. - - This method checks if the dataset contains the expected dimensions, - data variables, and metadata attributes. It also ensures that the - dataset contains valid poses or bounding boxes data. - - Raises - ------ - ValueError - If the dataset is missing required dimensions, data variables, - or contains invalid poses or bounding boxes data. - - """ - fps = self._obj.attrs.get("fps", None) - source_software = self._obj.attrs.get("source_software", None) - try: - self._validate_dims() - self._validate_data_vars() - if self._obj.ds_type == "poses": - ValidPosesDataset( - position_array=self._obj["position"].values, - confidence_array=self._obj["confidence"].values, - individual_names=self._obj.coords["individuals"].values, - keypoint_names=self._obj.coords["keypoints"].values, - fps=fps, - source_software=source_software, - ) - elif self._obj.ds_type == "bboxes": - # Define frame_array. - # Recover from time axis in seconds if necessary. - frame_array = self._obj.coords["time"].values.reshape(-1, 1) - if self._obj.attrs["time_unit"] == "seconds": - frame_array *= fps - ValidBboxesDataset( - position_array=self._obj["position"].values, - shape_array=self._obj["shape"].values, - confidence_array=self._obj["confidence"].values, - individual_names=self._obj.coords["individuals"].values, - frame_array=frame_array, - fps=fps, - source_software=source_software, - ) - except Exception as e: - error_msg = ( - f"The dataset does not contain valid {self._obj.ds_type}. {e}" - ) - raise log_error(ValueError, error_msg) from e - - def _validate_dims(self) -> None: - missing_dims = set(self.dim_names_instance) - set(self._obj.dims) - if missing_dims: - raise ValueError( - f"Missing required dimensions: {sorted(missing_dims)}" - ) # sort for a reproducible error message - - def _validate_data_vars(self) -> None: - missing_vars = set(self.var_names_instance) - set(self._obj.data_vars) - if missing_vars: - raise ValueError( - f"Missing required data variables: {sorted(missing_vars)}" - ) # sort for a reproducible error message diff --git a/movement/validators/datasets.py b/movement/validators/datasets.py index fd31246d..99a68c10 100644 --- a/movement/validators/datasets.py +++ b/movement/validators/datasets.py @@ -1,7 +1,7 @@ """``attrs`` classes for validating data structures.""" from collections.abc import Iterable -from typing import Any +from typing import Any, ClassVar import attrs import numpy as np @@ -142,6 +142,10 @@ class ValidPosesDataset: validator=validators.optional(validators.instance_of(str)), ) + # Class variables + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "keypoints", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "confidence") + # Add validators @position_array.validator def _validate_position_array(self, attribute, value): @@ -293,6 +297,9 @@ class ValidBboxesDataset: validator=validators.optional(validators.instance_of(str)), ) + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "shape", "confidence") + # Validators @position_array.validator @shape_array.validator diff --git a/tests/conftest.py b/tests/conftest.py index 272e5eaa..6da9a598 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,9 @@ import pytest import xarray as xr -from movement import MovementDataset from movement.sample_data import fetch_dataset_paths, list_datasets from movement.utils.logging import configure_logging +from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset def pytest_configure(): @@ -292,7 +292,7 @@ def valid_bboxes_dataset( """Return a valid bboxes dataset for two individuals moving in uniform linear motion, with 5 frames with low confidence values and time in frames. """ - dim_names = MovementDataset.dim_names["bboxes"] + dim_names = ValidBboxesDataset.DIM_NAMES position_array = valid_bboxes_arrays["position"] shape_array = valid_bboxes_arrays["shape"] @@ -376,7 +376,7 @@ def _valid_position_array(array_type): @pytest.fixture def valid_poses_dataset(valid_position_array, request): """Return a valid pose tracks dataset.""" - dim_names = MovementDataset.dim_names["poses"] + dim_names = ValidPosesDataset.DIM_NAMES # create a multi_individual_array by default unless overridden via param try: array_format = request.param @@ -490,7 +490,7 @@ def valid_poses_dataset_uniform_linear_motion( """Return a valid poses dataset for two individuals moving in uniform linear motion, with 5 frames with low confidence values and time in frames. """ - dim_names = MovementDataset.dim_names["poses"] + dim_names = ValidPosesDataset.DIM_NAMES position_array = valid_poses_array_uniform_linear_motion["position"] confidence_array = valid_poses_array_uniform_linear_motion["confidence"] diff --git a/tests/test_integration/test_filtering.py b/tests/test_integration/test_filtering.py index cba430f0..e3e87901 100644 --- a/tests/test_integration/test_filtering.py +++ b/tests/test_integration/test_filtering.py @@ -1,8 +1,10 @@ -from contextlib import nullcontext as does_not_raise - import pytest -import xarray as xr +from movement.filtering import ( + filter_by_confidence, + interpolate_over_time, + savgol_filter, +) from movement.io import load_poses from movement.sample_data import fetch_dataset_paths @@ -14,7 +16,6 @@ def sample_dataset(): "poses" ] ds = load_poses.from_dlc_file(ds_path) - return ds @@ -31,13 +32,18 @@ def test_nan_propagation_through_filters(sample_dataset, window, helpers): # Check filter position by confidence creates correct number of NaNs sample_dataset.update( - {"position": sample_dataset.move.filter_by_confidence()} + { + "position": filter_by_confidence( + sample_dataset.position, + sample_dataset.confidence, + ) + } ) n_total_nans_input = helpers.count_nans(sample_dataset.position) assert ( n_total_nans_input - == n_low_confidence_kpts * sample_dataset.dims["space"] + == n_low_confidence_kpts * sample_dataset.sizes["space"] ) # Compute maximum expected increase in NaNs due to filtering @@ -48,7 +54,11 @@ def test_nan_propagation_through_filters(sample_dataset, window, helpers): # Apply savgol filter and check that number of NaNs is within threshold sample_dataset.update( - {"position": sample_dataset.move.savgol_filter(window, polyorder=2)} + { + "position": savgol_filter( + sample_dataset.position, window, polyorder=2 + ) + } ) n_total_nans_savgol = helpers.count_nans(sample_dataset.position) @@ -60,56 +70,6 @@ def test_nan_propagation_through_filters(sample_dataset, window, helpers): # Interpolate data (without max_gap) and check it eliminates all NaNs sample_dataset.update( - {"position": sample_dataset.move.interpolate_over_time()} + {"position": interpolate_over_time(sample_dataset.position)} ) assert helpers.count_nans(sample_dataset.position) == 0 - - -@pytest.mark.parametrize( - "method", - [ - "filter_by_confidence", - "interpolate_over_time", - "median_filter", - "savgol_filter", - ], -) -@pytest.mark.parametrize( - "data_vars, expected_exception", - [ - (None, does_not_raise(xr.DataArray)), - (["position", "velocity"], does_not_raise(dict)), - (["vlocity"], pytest.raises(RuntimeError)), # Does not exist - ], -) -def test_accessor_filter_method( - sample_dataset, method, data_vars, expected_exception -): - """Test that filtering methods in the ``move`` accessor - return the expected data type and structure, and the - expected ``log`` attribute containing the filtering method - applied, if valid data variables are passed, otherwise - raise an exception. - """ - # Compute velocity - sample_dataset["velocity"] = sample_dataset.move.compute_velocity() - - with expected_exception as expected_type: - if method in ["median_filter", "savgol_filter"]: - # supply required "window" argument - result = getattr(sample_dataset.move, method)( - data_vars=data_vars, window=3 - ) - else: - result = getattr(sample_dataset.move, method)(data_vars=data_vars) - assert isinstance(result, expected_type) - if isinstance(result, xr.DataArray): - assert hasattr(result, "log") - assert result.log[0]["operation"] == method - elif isinstance(result, dict): - assert set(result.keys()) == set(data_vars) - assert all(hasattr(value, "log") for value in result.values()) - assert all( - value.log[0]["operation"] == method - for value in result.values() - ) diff --git a/tests/test_integration/test_kinematics_vector_transform.py b/tests/test_integration/test_kinematics_vector_transform.py index 63ecc2e4..5fa1b91c 100644 --- a/tests/test_integration/test_kinematics_vector_transform.py +++ b/tests/test_integration/test_kinematics_vector_transform.py @@ -4,6 +4,7 @@ import pytest import xarray as xr +import movement.analysis.kinematics as kin from movement.utils import vector @@ -64,7 +65,9 @@ def test_cart2pol_transform_on_kinematics( with various kinematic properties. """ ds = request.getfixturevalue(valid_dataset_uniform_linear_motion) - kinematic_array_cart = getattr(ds.move, f"compute_{kinematic_variable}")() + kinematic_array_cart = getattr(kin, f"compute_{kinematic_variable}")( + ds.position + ) kinematic_array_pol = vector.cart2pol(kinematic_array_cart) # Build expected data array diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py index 474e6118..2f80459d 100644 --- a/tests/test_unit/test_load_bboxes.py +++ b/tests/test_unit/test_load_bboxes.py @@ -8,8 +8,8 @@ import pytest import xarray as xr -from movement import MovementDataset from movement.io import load_bboxes +from movement.validators.datasets import ValidBboxesDataset @pytest.fixture() @@ -127,7 +127,7 @@ def assert_dataset( assert dataset.confidence.shape == dataset.position.shape[:-1] # Check the dims and coords - DIM_NAMES = MovementDataset.dim_names["bboxes"] + DIM_NAMES = ValidBboxesDataset.DIM_NAMES assert all([i in dataset.dims for i in DIM_NAMES]) for d, dim in enumerate(DIM_NAMES[1:]): assert dataset.sizes[dim] == dataset.position.shape[d + 1] diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 8fedcdcb..77990a42 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -8,8 +8,8 @@ from sleap_io.io.slp import read_labels, write_labels from sleap_io.model.labels import LabeledFrame, Labels -from movement import MovementDataset from movement.io import load_poses +from movement.validators.datasets import ValidPosesDataset class TestLoadPoses: @@ -78,7 +78,7 @@ def assert_dataset( assert dataset.position.ndim == 4 assert dataset.confidence.shape == dataset.position.shape[:-1] # Check the dims and coords - DIM_NAMES = MovementDataset.dim_names["poses"] + DIM_NAMES = ValidPosesDataset.DIM_NAMES assert all([i in dataset.dims for i in DIM_NAMES]) for d, dim in enumerate(DIM_NAMES[1:]): assert dataset.sizes[dim] == dataset.position.shape[d + 1] diff --git a/tests/test_unit/test_move_accessor.py b/tests/test_unit/test_move_accessor.py deleted file mode 100644 index b87942e4..00000000 --- a/tests/test_unit/test_move_accessor.py +++ /dev/null @@ -1,128 +0,0 @@ -from contextlib import nullcontext as does_not_raise - -import pytest -import xarray as xr - - -@pytest.mark.parametrize( - "valid_dataset", ("valid_poses_dataset", "valid_bboxes_dataset") -) -def test_compute_kinematics_with_valid_dataset( - valid_dataset, kinematic_property, request -): - """Test that computing a kinematic property of a valid - poses or bounding boxes dataset via accessor methods returns - an instance of xr.DataArray. - """ - valid_input_dataset = request.getfixturevalue(valid_dataset) - - result = getattr( - valid_input_dataset.move, f"compute_{kinematic_property}" - )() - assert isinstance(result, xr.DataArray) - - -@pytest.mark.parametrize( - "invalid_dataset", - ( - "not_a_dataset", - "empty_dataset", - "missing_var_poses_dataset", - "missing_var_bboxes_dataset", - "missing_dim_poses_dataset", - "missing_dim_bboxes_dataset", - ), -) -def test_compute_kinematics_with_invalid_dataset( - invalid_dataset, kinematic_property, request -): - """Test that computing a kinematic property of an invalid - poses or bounding boxes dataset via accessor methods raises - the appropriate error. - """ - invalid_dataset = request.getfixturevalue(invalid_dataset) - expected_exception = ( - RuntimeError - if isinstance(invalid_dataset, xr.Dataset) - else AttributeError - ) - with pytest.raises(expected_exception): - getattr(invalid_dataset.move, f"compute_{kinematic_property}")() - - -@pytest.mark.parametrize( - "method", ["compute_invalid_property", "do_something"] -) -@pytest.mark.parametrize( - "valid_dataset", ("valid_poses_dataset", "valid_bboxes_dataset") -) -def test_invalid_move_method_call(valid_dataset, method, request): - """Test that invalid accessor method calls raise an AttributeError.""" - valid_input_dataset = request.getfixturevalue(valid_dataset) - with pytest.raises(AttributeError): - getattr(valid_input_dataset.move, method)() - - -@pytest.mark.parametrize( - "input_dataset, expected_exception, expected_patterns", - ( - ( - "valid_poses_dataset", - does_not_raise(), - [], - ), - ( - "valid_bboxes_dataset", - does_not_raise(), - [], - ), - ( - "valid_bboxes_dataset_in_seconds", - does_not_raise(), - [], - ), - ( - "missing_dim_poses_dataset", - pytest.raises(ValueError), - ["Missing required dimensions:", "['time']"], - ), - ( - "missing_dim_bboxes_dataset", - pytest.raises(ValueError), - ["Missing required dimensions:", "['time']"], - ), - ( - "missing_two_dims_bboxes_dataset", - pytest.raises(ValueError), - ["Missing required dimensions:", "['space', 'time']"], - ), - ( - "missing_var_poses_dataset", - pytest.raises(ValueError), - ["Missing required data variables:", "['position']"], - ), - ( - "missing_var_bboxes_dataset", - pytest.raises(ValueError), - ["Missing required data variables:", "['position']"], - ), - ( - "missing_two_vars_bboxes_dataset", - pytest.raises(ValueError), - ["Missing required data variables:", "['position', 'shape']"], - ), - ), -) -def test_move_validate( - input_dataset, expected_exception, expected_patterns, request -): - """Test the validate method returns the expected message.""" - input_dataset = request.getfixturevalue(input_dataset) - - with expected_exception as excinfo: - input_dataset.move.validate() - - if expected_patterns: - error_message = str(excinfo.value) - assert input_dataset.ds_type in error_message - assert all([pattern in error_message for pattern in expected_patterns]) diff --git a/tests/test_unit/test_save_poses.py b/tests/test_unit/test_save_poses.py index 0f606e31..592f0c9a 100644 --- a/tests/test_unit/test_save_poses.py +++ b/tests/test_unit/test_save_poses.py @@ -54,8 +54,8 @@ class TestSavePoses: ] invalid_poses_datasets_and_exceptions = [ - ("not_a_dataset", ValueError), - ("empty_dataset", RuntimeError), + ("not_a_dataset", TypeError), + ("empty_dataset", ValueError), ("missing_var_poses_dataset", ValueError), ("missing_dim_poses_dataset", ValueError), ] @@ -70,7 +70,7 @@ def output_file_params(self, request): @pytest.mark.parametrize( "ds, expected_exception", [ - (np.array([1, 2, 3]), pytest.raises(ValueError)), # incorrect type + (np.array([1, 2, 3]), pytest.raises(TypeError)), # incorrect type ( load_poses.from_dlc_file( DATA_PATHS.get("DLC_single-wasp.predictions.h5") diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py index 493f1d46..e41331f7 100644 --- a/tests/test_unit/test_validators/test_datasets_validators.py +++ b/tests/test_unit/test_validators/test_datasets_validators.py @@ -352,7 +352,7 @@ def test_bboxes_dataset_validator_confidence_array( ( np.arange(10).reshape(-1, 2), pytest.raises(ValueError), - "Expected 'frame_array' to have shape (10, 1), " "but got (5, 2).", + "Expected 'frame_array' to have shape (10, 1), but got (5, 2).", ), # frame_array should be a column vector ( [1, 2, 3],