Skip to content

Commit

Permalink
Add SuggestionFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed May 1, 2024
1 parent 61ab2d3 commit 5a0607d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 1 deletion.
1 change: 1 addition & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Instance,
PredictedInstance,
)
from sleap_io.model.suggestions import SuggestionFrame
from sleap_io.model.labeled_frame import LabeledFrame
from sleap_io.model.labels import Labels
from sleap_io.io.main import (
Expand Down
53 changes: 53 additions & 0 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Symmetry,
Node,
Track,
SuggestionFrame,
Point,
PredictedPoint,
Instance,
Expand Down Expand Up @@ -161,6 +162,55 @@ def write_tracks(labels_path: str, tracks: list[Track]):
f.create_dataset("tracks_json", data=tracks_json, maxshape=(None,))


def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFrame]:
"""Read `SuggestionFrame` dataset in a SLEAP labels file.
Args:
labels_path: A string path to the SLEAP labels file.
videos: A list of `Video` objects.
Returns:
A list of `SuggestionFrame` objects.
"""
suggestions = [
json.loads(x) for x in read_hdf5_dataset(labels_path, "suggestions_json")
]
suggestions_objects = []
for suggestion in suggestions:
suggestions_objects.append(
SuggestionFrame(
video=videos[int(suggestion["video"])],
frame_idx=suggestion["frame_idx"],
)
)
return suggestions_objects


def write_suggestions(
labels_path: str, suggestions: list[SuggestionFrame], videos: list[Video]
):
"""Write track metadata to a SLEAP labels file.
Args:
labels_path: A string path to the SLEAP labels file.
suggestions: A list of `SuggestionFrame` objects to store the metadata for.
videos: A list of `Video` objects.
"""
GROUP = 0 # TODO: Handle storing extraneous metadata.
suggestions_json = []
for suggestion in suggestions:
suggestion_dict = {
"video": str(videos.index(suggestion.video)),
"frame_idx": suggestion.frame_idx,
"group": GROUP,
}
suggestion_json = np.string_(json.dumps(suggestion_dict, separators=(",", ":")))
suggestions_json.append(suggestion_json)

with h5py.File(labels_path, "a") as f:
f.create_dataset("suggestions_json", data=suggestions_json, maxshape=(None,))


def read_metadata(labels_path: str) -> dict:
"""Read metadata from a SLEAP labels file.
Expand Down Expand Up @@ -649,6 +699,7 @@ def read_labels(labels_path: str) -> Labels:
instances = read_instances(
labels_path, skeletons, tracks, points, pred_points, format_id
)
suggestions = read_suggestions(labels_path, videos)
metadata = read_metadata(labels_path)
provenance = metadata.get("provenance", dict())

Expand All @@ -668,6 +719,7 @@ def read_labels(labels_path: str) -> Labels:
videos=videos,
skeletons=skeletons,
tracks=tracks,
suggestions=suggestions,
provenance=provenance,
)

Expand All @@ -685,5 +737,6 @@ def write_labels(labels_path: str, labels: Labels):
Path(labels_path).unlink()
write_videos(labels_path, labels.videos)
write_tracks(labels_path, labels.tracks)
write_suggestions(labels_path, labels.suggestions, labels.videos)
write_metadata(labels_path, labels)
write_lfs(labels_path, labels)
11 changes: 10 additions & 1 deletion sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
"""

from __future__ import annotations
from sleap_io import LabeledFrame, Instance, PredictedInstance, Video, Track
from sleap_io import (
LabeledFrame,
Instance,
PredictedInstance,
Video,
Track,
SuggestionFrame,
)
from attrs import define, field
from typing import Union, Optional, Any
import numpy as np
Expand All @@ -32,6 +39,7 @@ class Labels:
skeletons: A list of `Skeleton`s that are associated with this dataset. This
should generally only contain a single skeleton.
tracks: A list of `Track`s that are associated with this dataset.
suggestions: A list of `SuggestionFrame`s that are associated with this dataset.
provenance: Dictionary of arbitrary metadata providing additional information
about where the dataset came from.
Expand All @@ -44,6 +52,7 @@ class Labels:
videos: list[Video] = field(factory=list)
skeletons: list[Skeleton] = field(factory=list)
tracks: list[Track] = field(factory=list)
suggestions: list[SuggestionFrame] = field(factory=list)
provenance: dict[str, Any] = field(factory=dict)

def __attrs_post_init__(self):
Expand Down
19 changes: 19 additions & 0 deletions sleap_io/model/suggestions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Data module for suggestions."""

from __future__ import annotations
from sleap_io.model.video import Video
import attrs
from typing import Optional


@attrs.define(auto_attribs=True)
class SuggestionFrame:
"""Data structure for a single frame of suggestions.
Attributes:
video: The video associated with the frame.
frame_idx: The index of the frame in the video.
"""

video: Video
frame_idx: int
1 change: 1 addition & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_write_labels(centered_pair, slp_real_data, tmp_path):
assert len(saved_labels.skeletons) == len(labels.skeletons) == 1
assert saved_labels.skeleton.name == labels.skeleton.name
assert saved_labels.skeleton.node_names == labels.skeleton.node_names
assert len(saved_labels.suggestions) == len(labels.suggestions)


def test_load_multi_skeleton(tmpdir):
Expand Down

0 comments on commit 5a0607d

Please sign in to comment.