Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute Heading #315

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 88 additions & 4 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from typing import Literal

import numpy as np
import numpy.typing as npt
import xarray as xr

from movement.utils.logging import log_error
from movement.utils.vector import compute_norm
from movement.utils.vector import (
convert_to_unit,
signed_angle_between_2d_vectors,
)
from movement.validators.arrays import validate_dims_coords


Expand Down Expand Up @@ -176,7 +180,7 @@
left_keypoint: str,
right_keypoint: str,
camera_view: Literal["top_down", "bottom_up"] = "top_down",
):
) -> xr.DataArray:
"""Compute a 2D forward vector given two left-right symmetric keypoints.

The forward vector is computed as a vector perpendicular to the
Expand Down Expand Up @@ -278,15 +282,15 @@

# Return unit vector

return forward_vector / compute_norm(forward_vector)
return convert_to_unit(forward_vector)


def compute_head_direction_vector(
data: xr.DataArray,
left_keypoint: str,
right_keypoint: str,
camera_view: Literal["top_down", "bottom_up"] = "top_down",
):
) -> xr.DataArray:
"""Compute the 2D head direction vector given two keypoints on the head.

This function is an alias for :func:`compute_forward_vector()\
Expand Down Expand Up @@ -324,6 +328,86 @@
)


def compute_heading(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reminder to self: rename this to compute_heading_angle()?

data: xr.DataArray,
left_keypoint: str,
right_keypoint: str,
reference_vector: npt.NDArray | list | tuple = (1, 0),
camera_view: Literal["top_down", "bottom_up"] = "top_down",
in_radians=False,
) -> xr.DataArray:
"""Compute the 2D heading given two keypoints on the head.

Heading is defined as the signed angle between the animal's forward
vector (see :func:`compute_forward_direction()\
<movement.analysis.kinematics.compute_forward_direction>`)
and a reference vector. By default, the reference vector
corresponds to the direction of the positive x-axis.

Parameters
----------
data : xarray.DataArray
The input data representing position. This must contain
the two symmetrical keypoints located on the left and
right sides of the body, respectively.
left_keypoint : str
Name of the left keypoint, e.g., "left_ear"
right_keypoint : str
Name of the right keypoint, e.g., "right_ear"
reference_vector : ndt.NDArray | list | tuple, optional
The reference vector against which the ```forward_vector`` is
compared to compute 2D heading. Must be a two-dimensional vector,
in the form [x,y] - where reference_vector[0] corresponds to the
x-coordinate and reference_vector[1] corresponds to the
y-coordinate. If left unspecified, the vector [1, 0] is used by
default.
camera_view : Literal["top_down", "bottom_up"], optional
The camera viewing angle, used to determine the upwards
direction of the animal. Can be either ``"top_down"`` (where the
upwards direction is [0, 0, -1]), or ``"bottom_up"`` (where the
upwards direction is [0, 0, 1]). If left unspecified, the camera
view is assumed to be ``"top_down"``.
in_radians : bool, optional
If true, the returned heading array is given in radians.
If false, the array is given in degrees. False by default.

Returns
-------
xarray.DataArray
An xarray DataArray containing the computed heading
timeseries, with dimensions matching the input data array,
but without the ``keypoints`` and ``space`` dimensions.

"""
# Convert reference vector to np.array if list or tuple
if isinstance(reference_vector, (list | tuple)):
reference_vector = np.array(reference_vector)

Check warning on line 384 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L383-L384

Added lines #L383 - L384 were not covered by tests

# Validate that reference vector has correct dimensionality
if reference_vector.shape != (2,):
raise log_error(

Check warning on line 388 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L387-L388

Added lines #L387 - L388 were not covered by tests
ValueError,
f"Reference vector must be two-dimensional (with"
f" shape ``(2,)``), but got {reference_vector.shape}.",
)

# Compute forward vector
forward_vector = compute_forward_vector(

Check warning on line 395 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L395

Added line #L395 was not covered by tests
data, left_keypoint, right_keypoint, camera_view=camera_view
)

# Compute signed angle between forward vector and reference vector
heading_array = signed_angle_between_2d_vectors(

Check warning on line 400 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L400

Added line #L400 was not covered by tests
forward_vector, reference_vector
)

# Convert to degrees
if not in_radians:
heading_array = np.rad2deg(heading_array)

Check warning on line 406 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L405-L406

Added lines #L405 - L406 were not covered by tests

return heading_array

Check warning on line 408 in movement/analysis/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/analysis/kinematics.py#L408

Added line #L408 was not covered by tests


def _validate_type_data_array(data: xr.DataArray) -> None:
"""Validate the input data is an xarray DataArray.

Expand Down
165 changes: 165 additions & 0 deletions movement/utils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,174 @@
).transpose(*dims)


def signed_angle_between_2d_vectors(
test_vector: xr.DataArray, reference_vector: xr.DataArray | np.ndarray
) -> xr.DataArray:
"""Compute the signed angle between two 2-D vectors.

Parameters
----------
test_vector : xarray.DataArray
An array of position vectors containing the ``space``
dimension with only ``"x"`` and ``"y"`` coordinates.
reference_vector : xarray.DataArray | numpy.ndarray
A 2D vector (or array of 2D vectors) against which to
compare ``test_vector``. May either be an xarray
DataArray containing the ``space`` dimension or a numpy
array containing one or more 2D vectors. (See Notes)

Returns
-------
xarray.DataArray :
An xarray DataArray containing signed angle between
``test_vector`` and ``reference_vector`` for every
time-point. Matches the dimensions of ``test_vector``,
but without the ``space`` dimension.

Notes
-----
If passed as an xarray DataArray, the reference vector must
have the spatial coordinates ``x`` and ``y`` only, and must
have a ``time`` dimension matching that of the test vector.

If passed as a numpy array, the reference vector must have
one of three shapes:
1. ``(2,)`` - Where dimension ``0`` contains spatial
coordinates (x,y), and no time dimension is specified.
2. ``(1,2)`` - Where dimension ``0`` corresponds to a
single time-point and dimension ``1`` contains spatial
coordinates (x,y).
3. ``(n,2)`` - Where dimension ``0`` corresponds to
time and dimension ``1`` contains spatial coordinates
(x,y), and where ``n == len(test_vector.time)``.

Reference vectors containing more dimensions, or with shapes
otherwise different from those defined above are considered
invalid.

"""
if isinstance(reference_vector, np.ndarray) and reference_vector.shape == (

Check warning on line 214 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L214

Added line #L214 was not covered by tests
2,
):
reference_vector = reference_vector.reshape(1, 2)

Check warning on line 217 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L217

Added line #L217 was not covered by tests

validate_dims_coords(test_vector, {"space": ["x", "y"]})
_validate_reference_vector(reference_vector, test_vector)

Check warning on line 220 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L219-L220

Added lines #L219 - L220 were not covered by tests

test_unit = convert_to_unit(test_vector)
test_x = test_unit.sel(space="x")
test_y = test_unit.sel(space="y")

Check warning on line 224 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L222-L224

Added lines #L222 - L224 were not covered by tests

if isinstance(reference_vector, xr.DataArray):
ref_unit = convert_to_unit(reference_vector)
ref_x = ref_unit.sel(space="x")
ref_y = ref_unit.sel(space="y")

Check warning on line 229 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L226-L229

Added lines #L226 - L229 were not covered by tests
else:
ref_unit = reference_vector / np.linalg.norm(reference_vector)
ref_x = np.take(ref_unit, 0, axis=-1).reshape(-1, 1)
ref_y = np.take(ref_unit, 1, axis=-1).reshape(-1, 1)

Check warning on line 233 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L231-L233

Added lines #L231 - L233 were not covered by tests

signed_angles = np.arctan2(

Check warning on line 235 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L235

Added line #L235 was not covered by tests
test_y * ref_x - test_x * ref_y,
test_x * ref_x + test_y * ref_y,
)

return signed_angles

Check warning on line 240 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L240

Added line #L240 was not covered by tests


def _raise_error_for_missing_spatial_dim() -> None:
raise log_error(
ValueError,
"Input data array must contain either 'space' or 'space_pol' "
"as dimensions.",
)


def _validate_reference_vector(
reference_vector: xr.DataArray | np.ndarray, test_vector: xr.DataArray
):
"""Validate the reference vector has the correct type and dimensions.

Parameters
----------
reference_vector : xarray.DataArray | numpy.ndarray
The reference vector array to validate.
test_vector : xarray.DataArray
The input data against which to validate the
reference vector.

Returns
-------
TypeError
If reference_vector is not an xarray DataArray or
a numpy array
ValueError
If reference_vector does not have the correct dimensions

"""
# Validate reference vector type
if not isinstance(reference_vector, (xr.DataArray | np.ndarray)):
raise log_error(

Check warning on line 275 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L274-L275

Added lines #L274 - L275 were not covered by tests
TypeError,
f"Reference vector must be an xarray.DataArray or a np.ndarray, "
f"but got {type(reference_vector)}.",
)
if isinstance(reference_vector, xr.DataArray):
validate_dims_coords(

Check warning on line 281 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L280-L281

Added lines #L280 - L281 were not covered by tests
reference_vector,
{
"space": ["x", "y"],
},
)
# Check reference_vector is 2D
if len(reference_vector.space) > 2:
raise log_error(

Check warning on line 289 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L288-L289

Added lines #L288 - L289 were not covered by tests
ValueError,
"Reference vector may not have more than 2 spatial "
"coordinates.",
)
# Check reference vector has valid time dimension
if "time" in reference_vector.dims and not len(

Check warning on line 295 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L295

Added line #L295 was not covered by tests
reference_vector.time
) == len(test_vector.time):
raise log_error(

Check warning on line 298 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L298

Added line #L298 was not covered by tests
ValueError,
"Input data and reference vector must have matching time "
"dimensions.",
)
if any(dim not in ["time", "space"] for dim in reference_vector.dims):
raise log_error(

Check warning on line 304 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L303-L304

Added lines #L303 - L304 were not covered by tests
ValueError, "Reference vector contains invalid dimensions."
)
else:
if not (reference_vector.dtype == int) or (

Check warning on line 308 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L308

Added line #L308 was not covered by tests
reference_vector.dtype == float
):
raise log_error(

Check warning on line 311 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L311

Added line #L311 was not covered by tests
ValueError,
"Reference vector may only contain values of type ``int``"
"or ``float``.",
)
if not (

Check warning on line 316 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L316

Added line #L316 was not covered by tests
reference_vector.shape[0] == 1
or reference_vector.shape[0] == len(test_vector.time)
): # Validate time dim
raise log_error(

Check warning on line 320 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L320

Added line #L320 was not covered by tests
ValueError,
"Dimension ``0`` of the reference vector must have length "
"``1`` or be equal in length to the ``time`` dimension of the "
"test vector.",
)
if not reference_vector.shape[-1] == 2: # Validate space dimension
raise log_error(

Check warning on line 327 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L326-L327

Added lines #L326 - L327 were not covered by tests
ValueError,
"Dimension ``-1`` of the reference_vector must correspond to "
"coordinates in 2-D space, and may therefore only have size "
f"``2``. Instead, got size ``{reference_vector.shape[1]}``.",
)
if len(reference_vector.shape) > 2:
raise log_error(

Check warning on line 334 in movement/utils/vector.py

View check run for this annotation

Codecov / codecov/patch

movement/utils/vector.py#L333-L334

Added lines #L333 - L334 were not covered by tests
ValueError,
"Reference vector may not have more than 2 dimensions (time"
"and space, respectively)",
)
Loading