Skip to content

Commit

Permalink
Fix suggestions deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed May 22, 2024
1 parent 45941db commit 05f587a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,11 @@ def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFr
Returns:
A list of `SuggestionFrame` objects.
"""
suggestions = [
json.loads(x) for x in read_hdf5_dataset(labels_path, "suggestions_json")
]
try:
suggestions = read_hdf5_dataset(labels_path, "suggestions_json")
except KeyError:
return []
suggestions = [json.loads(x) for x in suggestions]
suggestions_objects = []
for suggestion in suggestions:
suggestions_objects.append(
Expand Down
20 changes: 20 additions & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
PredictedPoint,
PredictedInstance,
Labels,
SuggestionFrame,
)
from sleap_io.io.slp import (
read_videos,
Expand All @@ -29,6 +30,8 @@
write_lfs,
read_labels,
write_labels,
read_suggestions,
write_suggestions,
)
from sleap_io.io.utils import read_hdf5_dataset
import numpy as np
Expand Down Expand Up @@ -237,3 +240,20 @@ def test_slp_imgvideo(tmpdir, slp_imgvideo):
assert type(videos[0].backend) == ImageVideo
assert len(videos[0].filename) == 2
assert videos[0].shape is None


def test_suggestions(tmpdir):
labels = Labels()
labels.videos.append(Video.from_filename("fake.mp4"))
labels.suggestions.append(SuggestionFrame(video=labels.video, frame_idx=0))

write_suggestions(tmpdir / "test.slp", labels.suggestions, labels.videos)
loaded_suggestions = read_suggestions(tmpdir / "test.slp", labels.videos)
assert len(loaded_suggestions) == 1
assert loaded_suggestions[0].video.filename == "fake.mp4"
assert loaded_suggestions[0].frame_idx == 0

# Handle missing suggestions dataset
write_videos(tmpdir / "test2.slp", labels.videos)
loaded_suggestions = read_suggestions(tmpdir / "test2.slp", labels.videos)
assert len(loaded_suggestions) == 0

0 comments on commit 05f587a

Please sign in to comment.