Skip to content

Commit

Permalink
Labels.split and Labels.make_training_splits (#98)
Browse files Browse the repository at this point in the history
* Add "all" embed target, clean up embedding docs, dedup, tests

* Video typing things and cleanup

* Labels.split

* Add provenance (filename) at load time

* Provenance serialization

* Add provenance to Labels.split

* Add Labels.make_training_splits

* Update docs

* Bump version

* Test fixes
  • Loading branch information
talmo committed Jun 19, 2024
1 parent aaeb0f2 commit 5a30110
Show file tree
Hide file tree
Showing 13 changed files with 485 additions and 44 deletions.
75 changes: 72 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ import sleap_io as sio
labels = sio.load_file("predictions.slp")

# Save to NWB file.
sio.save_file(labels, "predictions.nwb")
# Or:
# labels.save("predictions.nwb")
labels.save("predictions.nwb")
```

**See also:** [`Labels.save`](model.md#sleap_io.Labels.save) and [Formats](formats.md)


### Convert labels to raw arrays

```py
Expand All @@ -60,6 +61,9 @@ n_frames, n_tracks, n_nodes, xy_score = trx.shape
assert xy_score == 3
```

**See also:** [`Labels.numpy`](model.md#sleap_io.Labels.numpy)


### Read video data

```py
Expand All @@ -72,6 +76,10 @@ frame = video[0]
height, width, channels = frame.shape
```

**See also:** [`sio.load_video`](formats.md#sleap_io.load_video) and [`Video`](model.md#sleap_io.Video)



### Create labels from raw data

```py
Expand Down Expand Up @@ -107,6 +115,67 @@ labels = sio.Labels(videos=[video], skeletons=[skeleton], labeled_frames=[lf])
labels.save("labels.slp")
```

**See also:** [Model](model.md), [`Labels`](model.md#sleap_io.Labels),
[`LabeledFrame`](model.md#sleap_io.LabeledFrame),
[`Instance`](model.md#sleap_io.Instance),
[`PredictedInstance`](model.md#sleap_io.PredictedInstance),
[`Skeleton`](model.md#sleap_io.Skeleton), [`Video`](model.md#sleap_io.Video), [`Track`](model.md#sleap_io.Track), [`SuggestionFrame`](model.md#sleap_io.SuggestionFrame)


### Fix video paths

```py
import sleap_io as sio

labels = sio.load_file("labels.v001.slp")

# Fix paths using prefixes.
labels.replace_filenames(prefix_map={
"D:/data/sleap_projects": "/home/user/sleap_projects",
"C:/Users/sleaper/Desktop/test": "/home/user/sleap_projects",
})

labels.save("labels.v002.slp")
```

**See also:** [`Labels.replace_filenames`](model.md#sleap_io.Labels.replace_filenames)


### Save labels with embedded images

```py
import sleap_io as sio

# Load source labels.
labels = sio.load_file("labels.v001.slp")

# Save with embedded images for frames with user labeled data and suggested frames.
labels.save("labels.v001.pkg.slp", embed="user+suggestions")
```

**See also:** [`Labels.save`](model.md#sleap_io.Labels.save)


### Make training/validation/test splits

```py
import sleap_io as sio

# Load source labels.
labels = sio.load_file("labels.v001.slp")

# Make splits and export with embedded images.
labels.make_training_splits(n_train=0.8, n_val=0.1, n_test=0.1, save_dir="split1", seed=42)

# Splits will be saved as self-contained SLP package files with images and labels.
labels_train = sio.load_file("split1/train.pkg.slp")
labels_val = sio.load_file("split1/val.pkg.slp")
labels_test = sio.load_file("split1/test.pkg.slp")
```

**See also:** [`Labels.make_training_splits`](model.md#sleap_io.Labels.make_training_splits)


## Support
For technical inquiries specific to this package, please [open an Issue](https://github.com/talmolab/sleap-io/issues)
with a description of your problem or request.
Expand Down
6 changes: 3 additions & 3 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ extra_css:
- css/mkdocstrings.css

copyright: >
Copyright © 2024 - 2024 Talmo Lab –
Copyright © 2022 - 2024 Talmo Lab –
<a href="#__consent">Change cookie settings</a>
nav:
- Overview: index.md
- Changelog: https://github.com/talmolab/sleap-io/releases
- Core API:
- Data model: model.md
- Data formats: formats.md
- Model: model.md
- Formats: formats.md
- Full API: reference/
4 changes: 3 additions & 1 deletion sleap_io/io/jabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def read_labels(
instances.append(new_instance)
frame_label = LabeledFrame(video, frame_idx, instances)
frames.append(frame_label)
return Labels(frames)
labels = Labels(frames)
labels.provenance["filename"] = labels_path
return labels


def make_simple_skeleton(name: str, num_points: int) -> Skeleton:
Expand Down
4 changes: 3 additions & 1 deletion sleap_io/io/labelstudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def read_labels(
else:
assert isinstance(skeleton, Skeleton)

return parse_tasks(tasks, skeleton)
labels = parse_tasks(tasks, skeleton)
labels.provenance["filename"] = labels_path
return labels


def infer_nodes(tasks: List[Dict]) -> Skeleton:
Expand Down
22 changes: 17 additions & 5 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,29 @@ def load_slp(filename: str) -> Labels:


def save_slp(
labels: Labels, filename: str, embed: str | list[tuple[Video, int]] | None = None
labels: Labels,
filename: str,
embed: bool | str | list[tuple[Video, int]] | None = None,
):
"""Save a SLEAP dataset to a `.slp` file.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to save labels to ending with `.slp`.
embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or
list of tuples of `(video, frame_idx)` specifying the frames to embed. If
`"source"` is specified, no images will be embedded and the source video
embed: Frames to embed in the saved labels file. One of `None`, `True`,
`"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list
of tuples of `(video, frame_idx)`.
If `None` is specified (the default) and the labels contains embedded
frames, those embedded frames will be re-saved to the new file.
If `True` or `"all"`, all labeled frames and suggested frames will be
embedded.
If `"source"` is specified, no images will be embedded and the source video
will be restored if available.
This argument is only valid for the SLP backend.
"""
return slp.write_labels(filename, labels, embed=embed)

Expand Down Expand Up @@ -149,7 +161,7 @@ def load_file(
A `Labels` or `Video` object.
"""
if isinstance(filename, Path):
filename = str(filename)
filename = filename.as_posix()

if format is None:
if filename.endswith(".slp"):
Expand Down
1 change: 1 addition & 0 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def read_nwb(path: str) -> Labels:
LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
)
labels = Labels(lfs)
labels.provenance["filename"] = path
return labels


Expand Down
61 changes: 50 additions & 11 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def embed_frames(
to_embed_by_video[video] = []
to_embed_by_video[video].append(frame_idx)

for video in to_embed_by_video:
to_embed_by_video[video] = np.unique(to_embed_by_video[video])

replaced_videos = {}
for video, frame_inds in to_embed_by_video.items():
video_ind = labels.videos.index(video)
Expand All @@ -348,25 +351,40 @@ def embed_frames(


def embed_videos(
labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]]
labels_path: str, labels: Labels, embed: bool | str | list[tuple[Video, int]]
):
"""Embed videos in a SLEAP labels file.
Args:
labels_path: A string path to the SLEAP labels file to save.
labels: A `Labels` object to save.
embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or
list of tuples of `(video, frame_idx)` specifying the frames to embed. If
`"source"` is specified, no images will be embedded and the source video
embed: Frames to embed in the saved labels file. One of `None`, `True`,
`"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list
of tuples of `(video, frame_idx)`.
If `None` is specified (the default) and the labels contains embedded
frames, those embedded frames will be re-saved to the new file.
If `True` or `"all"`, all labeled frames and suggested frames will be
embedded.
If `"source"` is specified, no images will be embedded and the source video
will be restored if available.
This argument is only valid for the SLP backend.
"""
if embed is True:
embed = "all"
if embed == "user":
embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames]
elif embed == "suggestions":
embed = [(sf.video, sf.frame_idx) for sf in labels.suggestions]
elif embed == "user+suggestions":
embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames]
embed += [(sf.video, sf.frame_idx) for sf in labels.suggestions]
elif embed == "all":
embed = [(lf.video, lf.frame_idx) for lf in labels]
embed += [(sf.video, sf.frame_idx) for sf in labels.suggestions]
elif embed == "source":
embed = []
elif isinstance(embed, list):
Expand Down Expand Up @@ -711,6 +729,13 @@ def write_metadata(labels_path: str, labels: Labels):
"negative_anchors": {},
"provenance": labels.provenance,
}

# Custom encoding.
for k in md["provenance"]:
if isinstance(md["provenance"][k], Path):
# Path -> str
md["provenance"][k] = md["provenance"][k].as_posix()

with h5py.File(labels_path, "a") as f:
grp = f.require_group("metadata")
grp.attrs["format_id"] = 1.2
Expand Down Expand Up @@ -942,7 +967,9 @@ def write_lfs(labels_path: str, labels: Labels):

# Link instances based on from_predicted field.
for instance_id, from_predicted in to_link:
instances[instance_id][5] = inst_to_id[id(from_predicted)]
# Source instance may be missing if predictions were removed from the labels, in
# which case, remove the link.
instances[instance_id][5] = inst_to_id.get(id(from_predicted), -1)

# Create structured arrays.
points = np.array([tuple(x) for x in points], dtype=point_dtype)
Expand Down Expand Up @@ -1013,23 +1040,35 @@ def read_labels(labels_path: str) -> Labels:
suggestions=suggestions,
provenance=provenance,
)
labels.provenance["filename"] = labels_path

return labels


def write_labels(
labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]] | None = None
labels_path: str,
labels: Labels,
embed: bool | str | list[tuple[Video, int]] | None = None,
):
"""Write a SLEAP labels file.
Args:
labels_path: A string path to the SLEAP labels file to save.
labels: A `Labels` object to save.
embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"`,
`None` or list of tuples of `(video, frame_idx)` specifying the frames to
embed. If `"source"` is specified, no images will be embedded and the source
video will be restored if available. If `None` is specified (the default),
existing embedded images will be re-embedded.
embed: Frames to embed in the saved labels file. One of `None`, `True`,
`"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list
of tuples of `(video, frame_idx)`.
If `None` is specified (the default) and the labels contains embedded
frames, those embedded frames will be re-saved to the new file.
If `True` or `"all"`, all labeled frames and suggested frames will be
embedded.
If `"source"` is specified, no images will be embedded and the source video
will be restored if available.
This argument is only valid for the SLP backend.
"""
if Path(labels_path).exists():
Path(labels_path).unlink()
Expand Down
2 changes: 1 addition & 1 deletion sleap_io/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def find_embedded(name, obj):
self.image_format = ds.attrs["format"]

if "frame_numbers" in ds.parent:
frame_numbers = ds.parent["frame_numbers"][:]
frame_numbers = ds.parent["frame_numbers"][:].astype(int)
self.frame_map = {frame: idx for idx, frame in enumerate(frame_numbers)}
self.source_inds = frame_numbers

Expand Down
Loading

0 comments on commit 5a30110

Please sign in to comment.