Skip to content

Commit

Permalink
Add Labels.find and Labels.video (#35)
Browse files Browse the repository at this point in the history
* Add basic implementation of labels.find and labels.video

* Bump version
  • Loading branch information
talmo authored Jun 23, 2023
1 parent 45eaf98 commit 4e6940b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
__version__ = "0.0.3"
__version__ = "0.0.4"

from sleap_io.model.skeleton import Node, Edge, Skeleton, Symmetry
from sleap_io.model.video import Video
Expand Down
61 changes: 60 additions & 1 deletion sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __attrs_post_init__(self):
if inst.track is not None and inst.track not in self.tracks:
self.tracks.append(inst.track)

def __getitem__(self, key: int) -> Union[list[LabeledFrame], LabeledFrame]:
def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame:
"""Return one or more labeled frames based on indexing criteria."""
if type(key) == int:
return self.labeled_frames[key]
Expand Down Expand Up @@ -176,3 +176,62 @@ def numpy(
tracks[i, j] = inst.numpy(scores=return_confidence)

return tracks

@property
def video(self) -> Video:
"""Return the video if there is only a single video in the labels."""
if len(self.videos) == 0:
raise ValueError("There are no videos in the labels.")
elif len(self.videos) == 1:
return self.videos[0]
else:
raise ValueError(
"Labels.video can only be used when there is only a single video saved "
"in the labels. Use Labels.videos instead."
)

def find(
self,
video: Video,
frame_idx: int | list[int] | None = None,
return_new: bool = False,
) -> list[LabeledFrame]:
"""Search for labeled frames given video and/or frame index.
Args:
video: A `Video` that is associated with the project.
frame_idx: The frame index (or indices) which we want to find in the video.
If a range is specified, we'll return all frames with indices in that
range. If not specific, then we'll return all labeled frames for video.
return_new: Whether to return singleton of new and empty `LabeledFrame` if
none are found in project.
Returns:
List of `LabeledFrame` objects that match the criteria.
The list will be empty if no matches found, unless return_new is True,
in which case it contains new (empty) `LabeledFrame` objects with `video`
and `frame_index` set.
"""
results = []

if frame_idx is None:
for lf in self.labeled_frames:
if lf.video == video:
results.append(lf)
return results

if np.isscalar(frame_idx):
frame_idx = np.array(frame_idx).reshape(-1)

for frame_ind in frame_idx:
result = None
for lf in self.labeled_frames:
if lf.video == video and lf.frame_idx == frame_ind:
result = lf
results.append(result)
break
if result is None and return_new:
results.append(LabeledFrame(video=video, frame_idx=frame_ind))

return results
46 changes: 45 additions & 1 deletion tests/model/test_labels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
"""Test methods and functions in the sleap_io.model.labels file."""
from numpy.testing import assert_equal
import pytest
from sleap_io import Video, Skeleton, Instance, PredictedInstance, LabeledFrame
from sleap_io import (
Video,
Skeleton,
Instance,
PredictedInstance,
LabeledFrame,
load_slp,
)
from sleap_io.model.labels import Labels


Expand Down Expand Up @@ -57,3 +64,40 @@ def test_labels_numpy(labels_predictions: Labels):
inst.track = None
labels_predictions.tracks = []
assert labels_predictions.numpy(untracked=False).shape == (1100, 0, 24, 2)


def test_labels_find(slp_typical):
labels = load_slp(slp_typical)

results = labels.find(video=labels.video, frame_idx=0)
assert len(results) == 1
lf = results[0]
assert lf.frame_idx == 0

labels.labeled_frames.append(LabeledFrame(video=labels.video, frame_idx=1))

results = labels.find(video=labels.video)
assert len(results) == 2

results = labels.find(video=labels.video, frame_idx=2)
assert len(results) == 0

results = labels.find(video=labels.video, frame_idx=2, return_new=True)
assert len(results) == 1
assert results[0].frame_idx == 2
assert len(results[0]) == 0


def test_labels_video():
labels = Labels()

with pytest.raises(ValueError):
labels.video

vid = Video(filename="test")
labels.videos.append(vid)
assert labels.video == vid

labels.videos.append(Video(filename="test2"))
with pytest.raises(ValueError):
labels.video

0 comments on commit 4e6940b

Please sign in to comment.