Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Labels.find and Labels.video #35

Merged
merged 2 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
__version__ = "0.0.3"
__version__ = "0.0.4"

from sleap_io.model.skeleton import Node, Edge, Skeleton, Symmetry
from sleap_io.model.video import Video
Expand Down
61 changes: 60 additions & 1 deletion sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __attrs_post_init__(self):
if inst.track is not None and inst.track not in self.tracks:
self.tracks.append(inst.track)

def __getitem__(self, key: int) -> Union[list[LabeledFrame], LabeledFrame]:
def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame:
"""Return one or more labeled frames based on indexing criteria."""
if type(key) == int:
return self.labeled_frames[key]
Expand Down Expand Up @@ -176,3 +176,62 @@ def numpy(
tracks[i, j] = inst.numpy(scores=return_confidence)

return tracks

@property
def video(self) -> Video:
"""Return the video if there is only a single video in the labels."""
if len(self.videos) == 0:
raise ValueError("There are no videos in the labels.")
elif len(self.videos) == 1:
return self.videos[0]
else:
raise ValueError(
"Labels.video can only be used when there is only a single video saved "
"in the labels. Use Labels.videos instead."
)

def find(
self,
video: Video,
frame_idx: int | list[int] | None = None,
return_new: bool = False,
) -> list[LabeledFrame]:
"""Search for labeled frames given video and/or frame index.

Args:
video: A `Video` that is associated with the project.
frame_idx: The frame index (or indices) which we want to find in the video.
If a range is specified, we'll return all frames with indices in that
range. If not specific, then we'll return all labeled frames for video.
return_new: Whether to return singleton of new and empty `LabeledFrame` if
none are found in project.

Returns:
List of `LabeledFrame` objects that match the criteria.

The list will be empty if no matches found, unless return_new is True,
in which case it contains new (empty) `LabeledFrame` objects with `video`
and `frame_index` set.
"""
results = []

if frame_idx is None:
for lf in self.labeled_frames:
if lf.video == video:
results.append(lf)
return results

if np.isscalar(frame_idx):
frame_idx = np.array(frame_idx).reshape(-1)

for frame_ind in frame_idx:
result = None
for lf in self.labeled_frames:
if lf.video == video and lf.frame_idx == frame_ind:
result = lf
results.append(result)
break
if result is None and return_new:
results.append(LabeledFrame(video=video, frame_idx=frame_ind))

return results
46 changes: 45 additions & 1 deletion tests/model/test_labels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
"""Test methods and functions in the sleap_io.model.labels file."""
from numpy.testing import assert_equal
import pytest
from sleap_io import Video, Skeleton, Instance, PredictedInstance, LabeledFrame
from sleap_io import (
Video,
Skeleton,
Instance,
PredictedInstance,
LabeledFrame,
load_slp,
)
from sleap_io.model.labels import Labels


Expand Down Expand Up @@ -57,3 +64,40 @@ def test_labels_numpy(labels_predictions: Labels):
inst.track = None
labels_predictions.tracks = []
assert labels_predictions.numpy(untracked=False).shape == (1100, 0, 24, 2)


def test_labels_find(slp_typical):
labels = load_slp(slp_typical)

results = labels.find(video=labels.video, frame_idx=0)
assert len(results) == 1
lf = results[0]
assert lf.frame_idx == 0

labels.labeled_frames.append(LabeledFrame(video=labels.video, frame_idx=1))

results = labels.find(video=labels.video)
assert len(results) == 2

results = labels.find(video=labels.video, frame_idx=2)
assert len(results) == 0

results = labels.find(video=labels.video, frame_idx=2, return_new=True)
assert len(results) == 1
assert results[0].frame_idx == 2
assert len(results[0]) == 0


def test_labels_video():
labels = Labels()

with pytest.raises(ValueError):
labels.video

vid = Video(filename="test")
labels.videos.append(vid)
assert labels.video == vid

labels.videos.append(Video(filename="test2"))
with pytest.raises(ValueError):
labels.video