Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SuggestionFrame #89

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Instance,
PredictedInstance,
)
from sleap_io.model.suggestions import SuggestionFrame
from sleap_io.model.labeled_frame import LabeledFrame
from sleap_io.model.labels import Labels
from sleap_io.io.main import (
Expand Down
53 changes: 53 additions & 0 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Symmetry,
Node,
Track,
SuggestionFrame,
Point,
PredictedPoint,
Instance,
Expand Down Expand Up @@ -161,6 +162,55 @@ def write_tracks(labels_path: str, tracks: list[Track]):
f.create_dataset("tracks_json", data=tracks_json, maxshape=(None,))


def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFrame]:
"""Read `SuggestionFrame` dataset in a SLEAP labels file.

Args:
labels_path: A string path to the SLEAP labels file.
videos: A list of `Video` objects.

Returns:
A list of `SuggestionFrame` objects.
"""
suggestions = [
json.loads(x) for x in read_hdf5_dataset(labels_path, "suggestions_json")
]
suggestions_objects = []
for suggestion in suggestions:
suggestions_objects.append(
SuggestionFrame(
video=videos[int(suggestion["video"])],
frame_idx=suggestion["frame_idx"],
)
)
return suggestions_objects


def write_suggestions(
labels_path: str, suggestions: list[SuggestionFrame], videos: list[Video]
):
"""Write track metadata to a SLEAP labels file.

Args:
labels_path: A string path to the SLEAP labels file.
suggestions: A list of `SuggestionFrame` objects to store the metadata for.
videos: A list of `Video` objects.
"""
GROUP = 0 # TODO: Handle storing extraneous metadata.
suggestions_json = []
for suggestion in suggestions:
suggestion_dict = {
"video": str(videos.index(suggestion.video)),
"frame_idx": suggestion.frame_idx,
"group": GROUP,
}
suggestion_json = np.string_(json.dumps(suggestion_dict, separators=(",", ":")))
suggestions_json.append(suggestion_json)

with h5py.File(labels_path, "a") as f:
f.create_dataset("suggestions_json", data=suggestions_json, maxshape=(None,))


def read_metadata(labels_path: str) -> dict:
"""Read metadata from a SLEAP labels file.

Expand Down Expand Up @@ -649,6 +699,7 @@ def read_labels(labels_path: str) -> Labels:
instances = read_instances(
labels_path, skeletons, tracks, points, pred_points, format_id
)
suggestions = read_suggestions(labels_path, videos)
metadata = read_metadata(labels_path)
provenance = metadata.get("provenance", dict())

Expand All @@ -668,6 +719,7 @@ def read_labels(labels_path: str) -> Labels:
videos=videos,
skeletons=skeletons,
tracks=tracks,
suggestions=suggestions,
provenance=provenance,
)

Expand All @@ -685,5 +737,6 @@ def write_labels(labels_path: str, labels: Labels):
Path(labels_path).unlink()
write_videos(labels_path, labels.videos)
write_tracks(labels_path, labels.tracks)
write_suggestions(labels_path, labels.suggestions, labels.videos)
write_metadata(labels_path, labels)
write_lfs(labels_path, labels)
11 changes: 10 additions & 1 deletion sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
"""

from __future__ import annotations
from sleap_io import LabeledFrame, Instance, PredictedInstance, Video, Track
from sleap_io import (
LabeledFrame,
Instance,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider removing or using the unused import Instance.

- from sleap_io import Instance

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
Instance,

PredictedInstance,
Video,
Track,
SuggestionFrame,
)
from attrs import define, field
from typing import Union, Optional, Any
import numpy as np
Expand All @@ -32,6 +39,7 @@ class Labels:
skeletons: A list of `Skeleton`s that are associated with this dataset. This
should generally only contain a single skeleton.
tracks: A list of `Track`s that are associated with this dataset.
suggestions: A list of `SuggestionFrame`s that are associated with this dataset.
provenance: Dictionary of arbitrary metadata providing additional information
about where the dataset came from.

Expand All @@ -44,6 +52,7 @@ class Labels:
videos: list[Video] = field(factory=list)
skeletons: list[Skeleton] = field(factory=list)
tracks: list[Track] = field(factory=list)
suggestions: list[SuggestionFrame] = field(factory=list)
provenance: dict[str, Any] = field(factory=dict)

def __attrs_post_init__(self):
Expand Down
18 changes: 18 additions & 0 deletions sleap_io/model/suggestions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Data module for suggestions."""

from __future__ import annotations
from sleap_io.model.video import Video
import attrs


@attrs.define(auto_attribs=True)
class SuggestionFrame:
"""Data structure for a single frame of suggestions.

Attributes:
video: The video associated with the frame.
frame_idx: The index of the frame in the video.
"""

video: Video
frame_idx: int
1 change: 1 addition & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_write_labels(centered_pair, slp_real_data, tmp_path):
assert len(saved_labels.skeletons) == len(labels.skeletons) == 1
assert saved_labels.skeleton.name == labels.skeleton.name
assert saved_labels.skeleton.node_names == labels.skeleton.node_names
assert len(saved_labels.suggestions) == len(labels.suggestions)


def test_load_multi_skeleton(tmpdir):
Expand Down
Loading