From 5a0607d14ed6c82d4e38f9fadd96b624b51d49bd Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 30 Apr 2024 22:32:56 -0700 Subject: [PATCH] Add SuggestionFrame --- sleap_io/__init__.py | 1 + sleap_io/io/slp.py | 53 +++++++++++++++++++++++++++++++++++ sleap_io/model/labels.py | 11 +++++++- sleap_io/model/suggestions.py | 19 +++++++++++++ tests/io/test_slp.py | 1 + 5 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 sleap_io/model/suggestions.py diff --git a/sleap_io/__init__.py b/sleap_io/__init__.py index 2b4e8d22..f3ab8dac 100644 --- a/sleap_io/__init__.py +++ b/sleap_io/__init__.py @@ -10,6 +10,7 @@ Instance, PredictedInstance, ) +from sleap_io.model.suggestions import SuggestionFrame from sleap_io.model.labeled_frame import LabeledFrame from sleap_io.model.labels import Labels from sleap_io.io.main import ( diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 5499d05b..4d31a577 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -12,6 +12,7 @@ Symmetry, Node, Track, + SuggestionFrame, Point, PredictedPoint, Instance, @@ -161,6 +162,55 @@ def write_tracks(labels_path: str, tracks: list[Track]): f.create_dataset("tracks_json", data=tracks_json, maxshape=(None,)) +def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFrame]: + """Read `SuggestionFrame` dataset in a SLEAP labels file. + + Args: + labels_path: A string path to the SLEAP labels file. + videos: A list of `Video` objects. + + Returns: + A list of `SuggestionFrame` objects. + """ + suggestions = [ + json.loads(x) for x in read_hdf5_dataset(labels_path, "suggestions_json") + ] + suggestions_objects = [] + for suggestion in suggestions: + suggestions_objects.append( + SuggestionFrame( + video=videos[int(suggestion["video"])], + frame_idx=suggestion["frame_idx"], + ) + ) + return suggestions_objects + + +def write_suggestions( + labels_path: str, suggestions: list[SuggestionFrame], videos: list[Video] +): + """Write track metadata to a SLEAP labels file. + + Args: + labels_path: A string path to the SLEAP labels file. + suggestions: A list of `SuggestionFrame` objects to store the metadata for. + videos: A list of `Video` objects. + """ + GROUP = 0 # TODO: Handle storing extraneous metadata. + suggestions_json = [] + for suggestion in suggestions: + suggestion_dict = { + "video": str(videos.index(suggestion.video)), + "frame_idx": suggestion.frame_idx, + "group": GROUP, + } + suggestion_json = np.string_(json.dumps(suggestion_dict, separators=(",", ":"))) + suggestions_json.append(suggestion_json) + + with h5py.File(labels_path, "a") as f: + f.create_dataset("suggestions_json", data=suggestions_json, maxshape=(None,)) + + def read_metadata(labels_path: str) -> dict: """Read metadata from a SLEAP labels file. @@ -649,6 +699,7 @@ def read_labels(labels_path: str) -> Labels: instances = read_instances( labels_path, skeletons, tracks, points, pred_points, format_id ) + suggestions = read_suggestions(labels_path, videos) metadata = read_metadata(labels_path) provenance = metadata.get("provenance", dict()) @@ -668,6 +719,7 @@ def read_labels(labels_path: str) -> Labels: videos=videos, skeletons=skeletons, tracks=tracks, + suggestions=suggestions, provenance=provenance, ) @@ -685,5 +737,6 @@ def write_labels(labels_path: str, labels: Labels): Path(labels_path).unlink() write_videos(labels_path, labels.videos) write_tracks(labels_path, labels.tracks) + write_suggestions(labels_path, labels.suggestions, labels.videos) write_metadata(labels_path, labels) write_lfs(labels_path, labels) diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index ba46e693..a9b979b2 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -12,7 +12,14 @@ """ from __future__ import annotations -from sleap_io import LabeledFrame, Instance, PredictedInstance, Video, Track +from sleap_io import ( + LabeledFrame, + Instance, + PredictedInstance, + Video, + Track, + SuggestionFrame, +) from attrs import define, field from typing import Union, Optional, Any import numpy as np @@ -32,6 +39,7 @@ class Labels: skeletons: A list of `Skeleton`s that are associated with this dataset. This should generally only contain a single skeleton. tracks: A list of `Track`s that are associated with this dataset. + suggestions: A list of `SuggestionFrame`s that are associated with this dataset. provenance: Dictionary of arbitrary metadata providing additional information about where the dataset came from. @@ -44,6 +52,7 @@ class Labels: videos: list[Video] = field(factory=list) skeletons: list[Skeleton] = field(factory=list) tracks: list[Track] = field(factory=list) + suggestions: list[SuggestionFrame] = field(factory=list) provenance: dict[str, Any] = field(factory=dict) def __attrs_post_init__(self): diff --git a/sleap_io/model/suggestions.py b/sleap_io/model/suggestions.py new file mode 100644 index 00000000..9fafa8b8 --- /dev/null +++ b/sleap_io/model/suggestions.py @@ -0,0 +1,19 @@ +"""Data module for suggestions.""" + +from __future__ import annotations +from sleap_io.model.video import Video +import attrs +from typing import Optional + + +@attrs.define(auto_attribs=True) +class SuggestionFrame: + """Data structure for a single frame of suggestions. + + Attributes: + video: The video associated with the frame. + frame_idx: The index of the frame in the video. + """ + + video: Video + frame_idx: int diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index 7a4b0973..399674c3 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -186,6 +186,7 @@ def test_write_labels(centered_pair, slp_real_data, tmp_path): assert len(saved_labels.skeletons) == len(labels.skeletons) == 1 assert saved_labels.skeleton.name == labels.skeleton.name assert saved_labels.skeleton.node_names == labels.skeleton.node_names + assert len(saved_labels.suggestions) == len(labels.suggestions) def test_load_multi_skeleton(tmpdir):