Skip to content

Commit

Permalink
Refactor cdist and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Sep 12, 2024
1 parent 6b43ecb commit a44615e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 112 deletions.
39 changes: 31 additions & 8 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,8 @@ def cdist(
core_dim = "individuals" if dim == "keypoints" else "keypoints"
elem1 = getattr(a, dim).item()
elem2 = getattr(b, dim).item()
if a.coords.get(core_dim) is None:
a = a.assign_coords({core_dim: "temp"})
if b.coords.get(core_dim) is None:
b = b.assign_coords({core_dim: "temp"})
if a.coords[core_dim].ndim == 0:
a = a.expand_dims(core_dim).transpose("time", "space", core_dim)
if b.coords[core_dim].ndim == 0:
b = b.expand_dims(core_dim).transpose("time", "space", core_dim)
a = _validate_core_dimension(a, core_dim)
b = _validate_core_dimension(b, core_dim)
result = xr.apply_ufunc(
_cdist,
a,
Expand Down Expand Up @@ -610,6 +604,35 @@ def _compute_pairwise_distances(
return pairwise_distances


def _validate_core_dimension(
data: xr.DataArray, core_dim: str
) -> xr.DataArray:
"""Validate the input data contains the required core dimension.
This function ensures the input data contains the ``core_dim``
required when applying :func:`scipy.spatial.distance.cdist` to
the input data, by adding a temporary dimension if necessary.
Parameters
----------
data : xarray.DataArray
The input data to validate.
core_dim : str
The core dimension to validate.
Returns
-------
xarray.DataArray
The input data with the core dimension validated.
"""
if data.coords.get(core_dim) is None:
data = data.assign_coords({core_dim: "temp_dim"})
if data.coords[core_dim].ndim == 0:
data = data.expand_dims(core_dim).transpose("time", "space", core_dim)
return data


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.
Expand Down
34 changes: 0 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,40 +574,6 @@ def kinematic_property(request):
return request.param


@pytest.fixture
def pairwise_distances_dataset():
"""Return a minimal poses dataset with 3 individuals
and 3 keypoints for pairwise distances computation.
"""
time = np.arange(2)
space = ["x", "y"]
individuals = ["ind1", "ind2", "ind3"]
keypoints = ["key1", "key2", "key3"]
data = np.array(
[
[
[[1, 1], [0, 0], [1, 0]],
[[1, 0], [1, 1], [0, 0]],
[[0, 0], [1, 0], [1, 1]],
],
[
[[3, 6], [1, 4], [0, 4]],
[[0, 4], [3, 6], [1, 4]],
[[1, 4], [0, 4], [3, 6]],
],
]
)
return xr.Dataset(
data_vars={
"position": xr.DataArray(
data,
coords=[time, individuals, keypoints, space],
dims=["time", "individuals", "keypoints", "space"],
)
}
)


# ---------------- VIA tracks CSV file fixtures ----------------------------
@pytest.fixture
def via_tracks_csv_with_invalid_header(tmp_path):
Expand Down
143 changes: 73 additions & 70 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
Expand Down Expand Up @@ -188,52 +187,42 @@ def test_approximate_derivative_with_invalid_order(order):


@pytest.mark.parametrize(
"dim, pairs, expected_data",
"dim, expected_data",
[
(
"individuals",
("ind1", "ind2"),
np.array(
[
[
[1.0, 0.0, np.sqrt(2)],
[1.0, np.sqrt(2), 0.0],
[0.0, 1.0, 1.0],
[1.0, np.sqrt(2), 0.0],
[1.0, 2.0, np.sqrt(2)],
],
[
[np.sqrt(13), 0.0, np.sqrt(8)],
[1.0, np.sqrt(8), 0.0],
[0.0, np.sqrt(13), 1.0],
[2.0, np.sqrt(5), 1.0],
[3.0, np.sqrt(10), 2.0],
[np.sqrt(5), np.sqrt(8), np.sqrt(2)],
],
]
),
),
(
"keypoints",
("key1", "key2"),
np.array(
[
[
[np.sqrt(2), 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, np.sqrt(2), 1.0],
],
[
[np.sqrt(8), 0.0, np.sqrt(13)],
[1.0, np.sqrt(13), 0.0],
[0.0, np.sqrt(8), 1.0],
],
]
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, np.sqrt(5)], [3.0, 1.0]]]
),
),
],
)
def test_cdist_with_known_values(
dim, pairs, expected_data, pairwise_distances_dataset
dim, expected_data, valid_poses_dataset_uniform_linear_motion
):
"""Test the computation of pairwise distances with known values."""
core_dim = "keypoints" if dim == "individuals" else "individuals"
input_dataarray = pairwise_distances_dataset.position
input_dataarray = valid_poses_dataset_uniform_linear_motion.position.sel(
time=slice(0, 1)
) # Use only the first two frames for simplicity
pairs = input_dataarray[dim].values[:2]
expected = xr.DataArray(
expected_data,
coords=[
Expand All @@ -252,62 +241,76 @@ def test_cdist_with_known_values(
)


@pytest.mark.parametrize(
"valid_dataset",
[
"valid_poses_dataset_uniform_linear_motion",
"valid_bboxes_dataset",
],
)
@pytest.mark.parametrize(
"selection_fn",
[
# individuals dim is scalar,
# poses: multiple keypoints
# bboxes: missing keypoints dim
# e.g. comparing 2 individuals from the same data array
lambda position: (
position.sel(individuals="ind1"),
position.sel(individuals="ind2"),
), # individuals dim is scalar
position.isel(individuals=0),
position.isel(individuals=1),
),
# individuals dim is 1D
# poses: multiple keypoints
# bboxes: missing keypoints dim
# e.g. comparing 2 single-individual data arrays
lambda position: (
position.where(
position.individuals == "ind1", drop=True
position.individuals == position.individuals[0], drop=True
).squeeze(),
position.where(
position.individuals == "ind2", drop=True
position.individuals == position.individuals[1], drop=True
).squeeze(),
), # individuals dim is 1D
lambda position: (
position.sel(individuals="ind1", keypoints="key1"),
position.sel(individuals="ind2", keypoints="key1"),
), # both individuals and keypoints dims are scalar
),
# both individuals and keypoints dims are scalar (poses only)
# e.g. comparing 2 individuals from the same data array,
# at the same keypoint
lambda position: (
position.where(position.keypoints == "key1", drop=True).sel(
individuals="ind1"
),
position.where(position.keypoints == "key1", drop=True).sel(
individuals="ind2"
),
), # keypoints dim is 1D
position.isel(individuals=0, keypoints=0),
position.isel(individuals=1, keypoints=0),
),
# individuals dim is scalar, keypoints dim is 1D (poses only)
# e.g. comparing 2 single-individual, single-keypoint data arrays
lambda position: (
position.drop_sel(keypoints=position.keypoints.values[1:])
.squeeze(drop=True)
.sel(individuals="ind1"),
position.drop_sel(keypoints=position.keypoints.values[1:])
.squeeze(drop=True)
.sel(individuals="ind2"),
), # missing core dim
position.where(
position.keypoints == position.keypoints[0], drop=True
).isel(individuals=0),
position.where(
position.keypoints == position.keypoints[0], drop=True
).isel(individuals=1),
),
],
ids=[
"dim_has_ndim_0",
"dim_has_ndim_1",
"core_dim_has_ndim_0",
"core_dim_has_ndim_1",
"missing_core_dim",
],
)
def test_cdist_with_single_dim_inputs(
pairwise_distances_dataset, selection_fn
):
"""Test that the computation of pairwise distances
works regardless of whether the input DataArrays have
```dim``` and ```core_dim``` being either scalar (ndim=0)
or 1D (ndim=1), or if ``core_dim`` is missing.
def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request):
"""Test that the pairwise distances data array is successfully
returned regardless of whether the input DataArrays have
``dim`` ("individuals") and ``core_dim`` ("keypoints")
being either scalar (ndim=0) or 1D (ndim=1),
or if ``core_dim`` is missing.
"""
position = pairwise_distances_dataset.position
a, b = selection_fn(position)
with does_not_raise():
kinematics.cdist(a, b, "individuals")
if request.node.callspec.id not in [
"core_dim_has_ndim_0-valid_bboxes_dataset",
"core_dim_has_ndim_1-valid_bboxes_dataset",
]: # Skip tests with keypoints dim for bboxes
valid_dataset = request.getfixturevalue(valid_dataset)
position = valid_dataset.position
a, b = selection_fn(position)
assert isinstance(kinematics.cdist(a, b, "individuals"), xr.DataArray)


def expected_pairwise_distances(pairs, input_ds, dim):
Expand Down Expand Up @@ -335,27 +338,27 @@ def expected_pairwise_distances(pairs, input_ds, dim):
@pytest.mark.parametrize(
"dim, pairs",
[
("individuals", {"ind1": ["ind2"]}), # list input
("individuals", {"ind1": "ind2"}), # string input
("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}),
("individuals", {"id_1": ["id_2"]}), # list input
("individuals", {"id_1": "id_2"}), # string input
("individuals", {"id_1": ["id_2"], "id_2": "id_1"}),
("individuals", None), # all pairs
("keypoints", {"key1": ["key2"]}), # list input
("keypoints", {"key1": "key2"}), # string input
("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}),
("keypoints", {"centroid": ["left"]}), # list input
("keypoints", {"centroid": "left"}), # string input
("keypoints", {"centroid": ["left"], "left": "right"}),
("keypoints", None), # all pairs
],
)
def test_compute_pairwise_distances_with_valid_pairs(
pairwise_distances_dataset, dim, pairs
valid_poses_dataset_uniform_linear_motion, dim, pairs
):
"""Test that the expected pairwise distances are computed
for valid ``pairs`` inputs.
"""
result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")(
pairwise_distances_dataset.position, pairs=pairs
valid_poses_dataset_uniform_linear_motion.position, pairs=pairs
)
expected_data_vars = expected_pairwise_distances(
pairs, pairwise_distances_dataset, dim
pairs, valid_poses_dataset_uniform_linear_motion, dim
)
if isinstance(result, dict):
assert set(result.keys()) == set(expected_data_vars)
Expand All @@ -364,10 +367,10 @@ def test_compute_pairwise_distances_with_valid_pairs(


def test_compute_pairwise_distances_with_invalid_dim(
pairwise_distances_dataset,
valid_poses_dataset_uniform_linear_motion,
):
"""Test that an error is raised when an invalid dimension is passed."""
with pytest.raises(ValueError):
kinematics._compute_pairwise_distances(
pairwise_distances_dataset.position, "invalid_dim"
valid_poses_dataset_uniform_linear_motion.position, "invalid_dim"
)

0 comments on commit a44615e

Please sign in to comment.