Skip to content

Commit

Permalink
Refactor pairwise distances tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Aug 28, 2024
1 parent 98a49ba commit 5813ffa
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
21 changes: 18 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,7 @@ def valid_bboxes_array():


@pytest.fixture
def valid_bboxes_dataset(
valid_bboxes_array,
):
def valid_bboxes_dataset(valid_bboxes_array):
"""Return a valid bboxes dataset with low confidence values and
time in frames.
"""
Expand Down Expand Up @@ -464,12 +462,29 @@ def missing_two_dims_bboxes_dataset(valid_bboxes_dataset):
return valid_bboxes_dataset.rename({"time": "tame", "space": "spice"})


# --------------------------- Kinematics fixtures ---------------------------
@pytest.fixture(params=["displacement", "velocity", "acceleration"])
def kinematic_property(request):
"""Return a kinematic property."""
return request.param


@pytest.fixture
def pairwise_distances_dataset(valid_poses_dataset):
"""Return a dataset in which the positions of either ``ind2`` or ``key2``
is offset by 1 unit (for testing pairwise distances computation).
"""

def _pairwise_distances_dataset(dim):
elem_name = f"{dim[:3]}2"
valid_poses_dataset.position.loc[{dim: elem_name}] = (
valid_poses_dataset.position.sel({dim: elem_name}) + 1
)
return valid_poses_dataset

return _pairwise_distances_dataset


# ---------------- VIA tracks CSV file fixtures ----------------------------
@pytest.fixture
def via_tracks_csv_with_invalid_header(tmp_path):
Expand Down
78 changes: 43 additions & 35 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,45 @@ def test_approximate_derivative_with_invalid_order(self, order):
with pytest.raises(expected_exception):
kinematics._compute_approximate_time_derivative(data, order=order)

def _generate_expected_data_vars(self, pairs, valid_poses_dataset, dim):
"""Generate expected data variables for pairwise distances."""
def expected_pairwise_distances(self, pairs, input_ds, dim):
"""Return a dictionary containing the expected data variable
names mapped to the expected data array for pairwise distances tests.
"""
expected_coord = "keypoints" if dim == "individuals" else "individuals"

def _expected_dataarray(fill_value):
return xr.full_like(
xr.DataArray(
coords={
"time": input_ds.time,
expected_coord: getattr(input_ds, expected_coord),
},
dims=["time", expected_coord],
),
fill_value=fill_value,
)

if pairs is None:
return [
f"dist_{elem1}_{elem2}"
for elem1, elem2 in itertools.combinations(
getattr(valid_poses_dataset, dim).values, 2
)
]
paired_elements = list(
itertools.combinations(getattr(input_ds, dim).values, 2)
)
else:
return [
f"dist_{elem1}_{elem2}"
paired_elements = [
(elem1, elem2)
for elem1, elem2_list in pairs.items()
for elem2 in (
[elem2_list] if isinstance(elem2_list, str) else elem2_list
)
]
expected_data = {
f"dist_{elem1}_{elem2}": _expected_dataarray(
0 if elem1 == elem2 else np.sqrt(2)
)
for elem1, elem2 in paired_elements
}
if len(expected_data) == 1:
return next(iter(expected_data.values()))
return expected_data

@pytest.mark.parametrize(
"dim, pairs",
Expand All @@ -143,34 +165,20 @@ def _generate_expected_data_vars(self, pairs, valid_poses_dataset, dim):
("keypoints", None), # all pairs
],
)
def test_compute_pairwise_distances(self, valid_poses_dataset, dim, pairs):
"""Test interkeypoint distances computation."""
expected_coord = "keypoints" if dim == "individuals" else "individuals"
expected_result = xr.zeros_like(
xr.DataArray(
coords={
"time": valid_poses_dataset.time,
expected_coord: getattr(
valid_poses_dataset, expected_coord
),
},
dims=["time", expected_coord],
)
)
def test_compute_pairwise_distances(
self, pairwise_distances_dataset, dim, pairs
):
"""Test pairwise distances computation."""
input_ds = pairwise_distances_dataset(dim)
result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")(
valid_poses_dataset.position, pairs=pairs
input_ds.position, pairs=pairs
)
expected_result = self.expected_pairwise_distances(
pairs, input_ds, dim
)
if isinstance(result, dict):
expected_data_vars = self._generate_expected_data_vars(
pairs, valid_poses_dataset, dim
)
# Assert expected pairs are present in the result
# and the results are zeros
for expected in expected_data_vars:
xr.testing.assert_equal(
result[expected],
expected_result,
)
for key in result:
xr.testing.assert_equal(result[key], expected_result[key])
else: # single DataArray
xr.testing.assert_equal(
result,
Expand Down

0 comments on commit 5813ffa

Please sign in to comment.