From 5a301106effd8cbcc6f23c9385e3c274b6b9262f Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 18 Jun 2024 19:36:01 -0700 Subject: [PATCH] `Labels.split` and `Labels.make_training_splits` (#98) * Add "all" embed target, clean up embedding docs, dedup, tests * Video typing things and cleanup * Labels.split * Add provenance (filename) at load time * Provenance serialization * Add provenance to Labels.split * Add Labels.make_training_splits * Update docs * Bump version * Test fixes --- docs/index.md | 75 +++++++++++++++- mkdocs.yml | 6 +- sleap_io/io/jabs.py | 4 +- sleap_io/io/labelstudio.py | 4 +- sleap_io/io/main.py | 22 +++-- sleap_io/io/nwb.py | 1 + sleap_io/io/slp.py | 61 ++++++++++--- sleap_io/io/video.py | 2 +- sleap_io/model/labels.py | 149 ++++++++++++++++++++++++++++++-- sleap_io/model/video.py | 7 +- sleap_io/version.py | 2 +- tests/io/test_slp.py | 25 ++++-- tests/model/test_labels.py | 171 +++++++++++++++++++++++++++++++++++++ 13 files changed, 485 insertions(+), 44 deletions(-) diff --git a/docs/index.md b/docs/index.md index 25327843..ccb3da85 100644 --- a/docs/index.md +++ b/docs/index.md @@ -37,11 +37,12 @@ import sleap_io as sio labels = sio.load_file("predictions.slp") # Save to NWB file. -sio.save_file(labels, "predictions.nwb") -# Or: -# labels.save("predictions.nwb") +labels.save("predictions.nwb") ``` +**See also:** [`Labels.save`](model.md#sleap_io.Labels.save) and [Formats](formats.md) + + ### Convert labels to raw arrays ```py @@ -60,6 +61,9 @@ n_frames, n_tracks, n_nodes, xy_score = trx.shape assert xy_score == 3 ``` +**See also:** [`Labels.numpy`](model.md#sleap_io.Labels.numpy) + + ### Read video data ```py @@ -72,6 +76,10 @@ frame = video[0] height, width, channels = frame.shape ``` +**See also:** [`sio.load_video`](formats.md#sleap_io.load_video) and [`Video`](model.md#sleap_io.Video) + + + ### Create labels from raw data ```py @@ -107,6 +115,67 @@ labels = sio.Labels(videos=[video], skeletons=[skeleton], labeled_frames=[lf]) labels.save("labels.slp") ``` +**See also:** [Model](model.md), [`Labels`](model.md#sleap_io.Labels), +[`LabeledFrame`](model.md#sleap_io.LabeledFrame), +[`Instance`](model.md#sleap_io.Instance), +[`PredictedInstance`](model.md#sleap_io.PredictedInstance), +[`Skeleton`](model.md#sleap_io.Skeleton), [`Video`](model.md#sleap_io.Video), [`Track`](model.md#sleap_io.Track), [`SuggestionFrame`](model.md#sleap_io.SuggestionFrame) + + +### Fix video paths + +```py +import sleap_io as sio + +labels = sio.load_file("labels.v001.slp") + +# Fix paths using prefixes. +labels.replace_filenames(prefix_map={ + "D:/data/sleap_projects": "/home/user/sleap_projects", + "C:/Users/sleaper/Desktop/test": "/home/user/sleap_projects", +}) + +labels.save("labels.v002.slp") +``` + +**See also:** [`Labels.replace_filenames`](model.md#sleap_io.Labels.replace_filenames) + + +### Save labels with embedded images + +```py +import sleap_io as sio + +# Load source labels. +labels = sio.load_file("labels.v001.slp") + +# Save with embedded images for frames with user labeled data and suggested frames. +labels.save("labels.v001.pkg.slp", embed="user+suggestions") +``` + +**See also:** [`Labels.save`](model.md#sleap_io.Labels.save) + + +### Make training/validation/test splits + +```py +import sleap_io as sio + +# Load source labels. +labels = sio.load_file("labels.v001.slp") + +# Make splits and export with embedded images. +labels.make_training_splits(n_train=0.8, n_val=0.1, n_test=0.1, save_dir="split1", seed=42) + +# Splits will be saved as self-contained SLP package files with images and labels. +labels_train = sio.load_file("split1/train.pkg.slp") +labels_val = sio.load_file("split1/val.pkg.slp") +labels_test = sio.load_file("split1/test.pkg.slp") +``` + +**See also:** [`Labels.make_training_splits`](model.md#sleap_io.Labels.make_training_splits) + + ## Support For technical inquiries specific to this package, please [open an Issue](https://github.com/talmolab/sleap-io/issues) with a description of your problem or request. diff --git a/mkdocs.yml b/mkdocs.yml index 06daa342..e2b1fc60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -144,13 +144,13 @@ extra_css: - css/mkdocstrings.css copyright: > - Copyright © 2024 - 2024 Talmo Lab – + Copyright © 2022 - 2024 Talmo Lab – Change cookie settings nav: - Overview: index.md - Changelog: https://github.com/talmolab/sleap-io/releases - Core API: - - Data model: model.md - - Data formats: formats.md + - Model: model.md + - Formats: formats.md - Full API: reference/ \ No newline at end of file diff --git a/sleap_io/io/jabs.py b/sleap_io/io/jabs.py index fd687f72..518acf04 100644 --- a/sleap_io/io/jabs.py +++ b/sleap_io/io/jabs.py @@ -169,7 +169,9 @@ def read_labels( instances.append(new_instance) frame_label = LabeledFrame(video, frame_idx, instances) frames.append(frame_label) - return Labels(frames) + labels = Labels(frames) + labels.provenance["filename"] = labels_path + return labels def make_simple_skeleton(name: str, num_points: int) -> Skeleton: diff --git a/sleap_io/io/labelstudio.py b/sleap_io/io/labelstudio.py index 2bb1f737..c9ae3b14 100644 --- a/sleap_io/io/labelstudio.py +++ b/sleap_io/io/labelstudio.py @@ -43,7 +43,9 @@ def read_labels( else: assert isinstance(skeleton, Skeleton) - return parse_tasks(tasks, skeleton) + labels = parse_tasks(tasks, skeleton) + labels.provenance["filename"] = labels_path + return labels def infer_nodes(tasks: List[Dict]) -> Skeleton: diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py index 49eebd6b..7fd702f7 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -20,17 +20,29 @@ def load_slp(filename: str) -> Labels: def save_slp( - labels: Labels, filename: str, embed: str | list[tuple[Video, int]] | None = None + labels: Labels, + filename: str, + embed: bool | str | list[tuple[Video, int]] | None = None, ): """Save a SLEAP dataset to a `.slp` file. Args: labels: A SLEAP `Labels` object (see `load_slp`). filename: Path to save labels to ending with `.slp`. - embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or - list of tuples of `(video, frame_idx)` specifying the frames to embed. If - `"source"` is specified, no images will be embedded and the source video + embed: Frames to embed in the saved labels file. One of `None`, `True`, + `"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list + of tuples of `(video, frame_idx)`. + + If `None` is specified (the default) and the labels contains embedded + frames, those embedded frames will be re-saved to the new file. + + If `True` or `"all"`, all labeled frames and suggested frames will be + embedded. + + If `"source"` is specified, no images will be embedded and the source video will be restored if available. + + This argument is only valid for the SLP backend. """ return slp.write_labels(filename, labels, embed=embed) @@ -149,7 +161,7 @@ def load_file( A `Labels` or `Video` object. """ if isinstance(filename, Path): - filename = str(filename) + filename = filename.as_posix() if format is None: if filename.endswith(".slp"): diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py index 2af64921..9314294a 100644 --- a/sleap_io/io/nwb.py +++ b/sleap_io/io/nwb.py @@ -140,6 +140,7 @@ def read_nwb(path: str) -> Labels: LabeledFrame(video=video, frame_idx=frame_idx, instances=insts) ) labels = Labels(lfs) + labels.provenance["filename"] = path return labels diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 4a1f5f85..562b7076 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -329,6 +329,9 @@ def embed_frames( to_embed_by_video[video] = [] to_embed_by_video[video].append(frame_idx) + for video in to_embed_by_video: + to_embed_by_video[video] = np.unique(to_embed_by_video[video]) + replaced_videos = {} for video, frame_inds in to_embed_by_video.items(): video_ind = labels.videos.index(video) @@ -348,18 +351,30 @@ def embed_frames( def embed_videos( - labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]] + labels_path: str, labels: Labels, embed: bool | str | list[tuple[Video, int]] ): """Embed videos in a SLEAP labels file. Args: labels_path: A string path to the SLEAP labels file to save. labels: A `Labels` object to save. - embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or - list of tuples of `(video, frame_idx)` specifying the frames to embed. If - `"source"` is specified, no images will be embedded and the source video + embed: Frames to embed in the saved labels file. One of `None`, `True`, + `"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list + of tuples of `(video, frame_idx)`. + + If `None` is specified (the default) and the labels contains embedded + frames, those embedded frames will be re-saved to the new file. + + If `True` or `"all"`, all labeled frames and suggested frames will be + embedded. + + If `"source"` is specified, no images will be embedded and the source video will be restored if available. + + This argument is only valid for the SLP backend. """ + if embed is True: + embed = "all" if embed == "user": embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] elif embed == "suggestions": @@ -367,6 +382,9 @@ def embed_videos( elif embed == "user+suggestions": embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] embed += [(sf.video, sf.frame_idx) for sf in labels.suggestions] + elif embed == "all": + embed = [(lf.video, lf.frame_idx) for lf in labels] + embed += [(sf.video, sf.frame_idx) for sf in labels.suggestions] elif embed == "source": embed = [] elif isinstance(embed, list): @@ -711,6 +729,13 @@ def write_metadata(labels_path: str, labels: Labels): "negative_anchors": {}, "provenance": labels.provenance, } + + # Custom encoding. + for k in md["provenance"]: + if isinstance(md["provenance"][k], Path): + # Path -> str + md["provenance"][k] = md["provenance"][k].as_posix() + with h5py.File(labels_path, "a") as f: grp = f.require_group("metadata") grp.attrs["format_id"] = 1.2 @@ -942,7 +967,9 @@ def write_lfs(labels_path: str, labels: Labels): # Link instances based on from_predicted field. for instance_id, from_predicted in to_link: - instances[instance_id][5] = inst_to_id[id(from_predicted)] + # Source instance may be missing if predictions were removed from the labels, in + # which case, remove the link. + instances[instance_id][5] = inst_to_id.get(id(from_predicted), -1) # Create structured arrays. points = np.array([tuple(x) for x in points], dtype=point_dtype) @@ -1013,23 +1040,35 @@ def read_labels(labels_path: str) -> Labels: suggestions=suggestions, provenance=provenance, ) + labels.provenance["filename"] = labels_path return labels def write_labels( - labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]] | None = None + labels_path: str, + labels: Labels, + embed: bool | str | list[tuple[Video, int]] | None = None, ): """Write a SLEAP labels file. Args: labels_path: A string path to the SLEAP labels file to save. labels: A `Labels` object to save. - embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"`, - `None` or list of tuples of `(video, frame_idx)` specifying the frames to - embed. If `"source"` is specified, no images will be embedded and the source - video will be restored if available. If `None` is specified (the default), - existing embedded images will be re-embedded. + embed: Frames to embed in the saved labels file. One of `None`, `True`, + `"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or list + of tuples of `(video, frame_idx)`. + + If `None` is specified (the default) and the labels contains embedded + frames, those embedded frames will be re-saved to the new file. + + If `True` or `"all"`, all labeled frames and suggested frames will be + embedded. + + If `"source"` is specified, no images will be embedded and the source video + will be restored if available. + + This argument is only valid for the SLP backend. """ if Path(labels_path).exists(): Path(labels_path).unlink() diff --git a/sleap_io/io/video.py b/sleap_io/io/video.py index 74c3d316..dae8e5e8 100644 --- a/sleap_io/io/video.py +++ b/sleap_io/io/video.py @@ -538,7 +538,7 @@ def find_embedded(name, obj): self.image_format = ds.attrs["format"] if "frame_numbers" in ds.parent: - frame_numbers = ds.parent["frame_numbers"][:] + frame_numbers = ds.parent["frame_numbers"][:].astype(int) self.frame_map = {frame: idx for idx, frame in enumerate(frame_numbers)} self.source_inds = frame_numbers diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index 6e6ae3e4..67dea5cf 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -25,6 +25,7 @@ import numpy as np from pathlib import Path from sleap_io.model.skeleton import Skeleton +from copy import deepcopy @define @@ -341,7 +342,7 @@ def save( self, filename: str, format: Optional[str] = None, - embed: str | list[tuple[Video, int]] | None = None, + embed: bool | str | list[tuple[Video, int]] | None = None, **kwargs, ): """Save labels to file in specified format. @@ -349,13 +350,22 @@ def save( Args: filename: Path to save labels to. format: The format to save the labels in. If `None`, the format will be - inferred from the file extension. Available formats are "slp", "nwb", - "labelstudio", and "jabs". - embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or - list of tuples of `(video, frame_idx)` specifying the frames to embed. + inferred from the file extension. Available formats are `"slp"`, + `"nwb"`, `"labelstudio"`, and `"jabs"`. + embed: Frames to embed in the saved labels file. One of `None`, `True`, + `"all"`, `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or + list of tuples of `(video, frame_idx)`. + + If `None` is specified (the default) and the labels contains embedded + frames, those embedded frames will be re-saved to the new file. + + If `True` or `"all"`, all labeled frames and suggested frames will be + embedded. + If `"source"` is specified, no images will be embedded and the source - video will be restored if available. This argument is only valid for the - SLP backend. + video will be restored if available. + + This argument is only valid for the SLP backend. """ from sleap_io import save_file @@ -551,3 +561,128 @@ def replace_filenames( video.replace_filename( new_prefix / fn.relative_to(old_prefix) ) + + def split(self, n: int | float, seed: int | None = None) -> tuple[Labels, Labels]: + """Separate the labels into random splits. + + Args: + n: Size of the first split. If integer >= 1, assumes that this is the number + of labeled frames in the first split. If < 1.0, this will be treated as + a fraction of the total labeled frames. + seed: Optional integer seed to use for reproducibility. + + Returns: + A tuple of `split1, split2`. + + If an integer was specified, `len(split1) == n`. + + If a fraction was specified, `len(split1) == int(n * len(labels))`. + + The second split contains the remainder, i.e., + `len(split2) == len(labels) - len(split1)`. + + If there are too few frames, a minimum of 1 frame will be kept in the second + split. + + If there is exactly 1 labeled frame in the labels, the same frame will be + assigned to both splits. + """ + n0 = len(self) + if n0 == 0: + return self, self + n1 = n + if n < 1.0: + n1 = max(int(n0 * float(n)), 1) + n2 = max(n0 - n1, 1) + n1, n2 = int(n1), int(n2) + + rng = np.random.default_rng(seed=seed) + inds1 = rng.choice(n0, size=(n1,), replace=False) + + if n0 == 1: + inds2 = np.array([0]) + else: + inds2 = np.setdiff1d(np.arange(n0), inds1) + + split1, split2 = self[inds1], self[inds2] + split1, split2 = deepcopy(split1), deepcopy(split2) + split1, split2 = Labels(split1), Labels(split2) + + split1.provenance = self.provenance + split2.provenance = self.provenance + split1.provenance["source_labels"] = self.provenance.get("filename", None) + split2.provenance["source_labels"] = self.provenance.get("filename", None) + + return split1, split2 + + def make_training_splits( + self, + n_train: int | float, + n_val: int | float | None = None, + n_test: int | float | None = None, + save_dir: str | Path | None = None, + seed: int | None = None, + ) -> tuple[Labels, Labels] | tuple[Labels, Labels, Labels]: + """Make splits for training with embedded images. + + Args: + n_train: Size of the training split as integer or fraction. + n_val: Size of the validation split as integer or fraction. If `None`, + this will be inferred based on the values of `n_train` and `n_test`. If + `n_test` is `None`, this will be the remainder of the data after the + training split. + n_test: Size of the testing split as integer or fraction. If `None`, the + test split will not be saved. + save_dir: If specified, save splits to SLP files with embedded images. + seed: Optional integer seed to use for reproducibility. + + Returns: + A tuple of `labels_train, labels_val` or + `labels_train, labels_val, labels_test` if `n_test` was specified. + + Notes: + Predictions and suggestions will be removed before saving, leaving only + frames with user labeled data (the source labels are not affected). + + Frames with user labeled data will be embedded in the resulting files. + + If `save_dir` is specified, this will save the randomly sampled splits to: + + - `{save_dir}/train.pkg.slp` + - `{save_dir}/val.pkg.slp` + - `{save_dir}/test.pkg.slp` (if `n_test` is specified) + + See also: `Labels.split` + """ + # Clean up labels. + labels = deepcopy(self) + labels.remove_predictions() + labels.suggestions = [] + labels.clean() + + # Make splits. + labels_train, labels_rest = labels.split(n_train, seed=seed) + if n_test is not None: + if n_test < 1: + n_test = (n_test * len(labels)) / len(labels_rest) + labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed) + if n_val is not None: + if n_val < 1: + n_val = (n_val * len(labels)) / len(labels_rest) + labels_val, _ = labels_rest.split(n=n_val, seed=seed) + else: + labels_val = labels_rest + + # Save. + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(exist_ok=True, parents=True) + + labels_train.save(save_dir / "train.pkg.slp", embed="user") + labels_val.save(save_dir / "val.pkg.slp", embed="user") + labels_test.save(save_dir / "test.pkg.slp", embed="user") + + if n_test is None: + return labels_train, labels_val + else: + return labels_train, labels_val, labels_test diff --git a/sleap_io/model/video.py b/sleap_io/model/video.py index a2644cd6..049963ba 100644 --- a/sleap_io/model/video.py +++ b/sleap_io/model/video.py @@ -50,11 +50,6 @@ class Video: EXTS = MediaVideo.EXTS + HDF5Video.EXTS + ImageVideo.EXTS - def __attrs_post_init__(self): - """Post init syntactic sugar.""" - if self.backend is None and self.exists(): - self.open() - def __attrs_post_init__(self): """Post init syntactic sugar.""" if self.backend is None and self.exists(): @@ -273,7 +268,7 @@ def replace_filename( the new filename does not exist, no error is raised. """ if isinstance(new_filename, Path): - new_filename = str(new_filename) + new_filename = new_filename.as_posix() if isinstance(new_filename, list): new_filename = [ diff --git a/sleap_io/version.py b/sleap_io/version.py index 1d04a962..c626e6a9 100644 --- a/sleap_io/version.py +++ b/sleap_io/version.py @@ -2,4 +2,4 @@ # Define package version. # This is read dynamically by setuptools in pyproject.toml to determine the release version. -__version__ = "0.1.4" +__version__ = "0.1.5" diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index 7d229bc4..cd37b9a1 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -283,7 +283,9 @@ def test_pkg_roundtrip(tmpdir, slp_minimal_pkg): ) -@pytest.mark.parametrize("to_embed", ["user", "suggestions", "user+suggestions"]) +@pytest.mark.parametrize( + "to_embed", [True, "all", "user", "suggestions", "user+suggestions"] +) def test_embed(tmpdir, slp_real_data, to_embed): base_labels = read_labels(slp_real_data) assert type(base_labels.video.backend) == MediaVideo @@ -306,8 +308,21 @@ def test_embed(tmpdir, slp_real_data, to_embed): Path(labels.video.source_video.filename).as_posix() == "tests/data/videos/centered_pair_low_quality.mp4" ) - if to_embed == "user": - assert labels.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + if to_embed == "all" or to_embed is True: + assert labels.video.backend.embedded_frame_inds == [ + 0, + 110, + 220, + 330, + 440, + 550, + 660, + 770, + 880, + 990, + ] + elif to_embed == "user": + assert labels.video.backend.embedded_frame_inds == [0, 220, 440, 770, 990] elif to_embed == "suggestions": assert len(labels.video.backend.embedded_frame_inds) == 10 elif to_embed == "suggestions+user": @@ -320,7 +335,7 @@ def test_embed_two_rounds(tmpdir, slp_real_data): write_labels(labels_path, base_labels, embed="user") labels = read_labels(labels_path) - assert labels.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + assert labels.video.backend.embedded_frame_inds == [0, 220, 440, 770, 990] labels2_path = str(tmpdir / "labels2.pkg.slp") write_labels(labels2_path, labels) @@ -329,7 +344,7 @@ def test_embed_two_rounds(tmpdir, slp_real_data): Path(labels2.video.source_video.filename).as_posix() == "tests/data/videos/centered_pair_low_quality.mp4" ) - assert labels2.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + assert labels2.video.backend.embedded_frame_inds == [0, 220, 440, 770, 990] labels3_path = str(tmpdir / "labels3.slp") write_labels(labels3_path, labels, embed="source") diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index cc783a0d..067a1d30 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -416,3 +416,174 @@ def test_replace_filenames(): labels.replace_filenames(prefix_map={"train/": "test/"}) assert labels.video.filename == ["test/imgs/img0.png", "test/imgs/img1.png"] + + +def test_split(slp_real_data, tmp_path): + # n = 0 + labels = Labels() + split1, split2 = labels.split(0.5) + assert len(split1) == len(split2) == 0 + + # n = 1 + labels.append(LabeledFrame(video=Video("test.mp4"), frame_idx=0)) + split1, split2 = labels.split(0.5) + assert len(split1) == len(split2) == 1 + assert split1[0].frame_idx == 0 + assert split2[0].frame_idx == 0 + + split1, split2 = labels.split(0.999) + assert len(split1) == len(split2) == 1 + assert split1[0].frame_idx == 0 + assert split2[0].frame_idx == 0 + + split1, split2 = labels.split(n=1) + assert len(split1) == len(split2) == 1 + assert split1[0].frame_idx == 0 + assert split2[0].frame_idx == 0 + + # Real data + labels = load_slp(slp_real_data) + assert len(labels) == 10 + + split1, split2 = labels.split(n=0.6) + assert len(split1) == 6 + assert len(split2) == 4 + + # Rounding errors + split1, split2 = labels.split(n=0.001) + assert len(split1) == 1 + assert len(split2) == 9 + + split1, split2 = labels.split(n=0.999) + assert len(split1) == 9 + assert len(split2) == 1 + + # Integer + split1, split2 = labels.split(n=8) + assert len(split1) == 8 + assert len(split2) == 2 + + # Serialization round trip + split1.save(tmp_path / "split1.slp") + split1_ = load_slp(tmp_path / "split1.slp") + assert len(split1) == len(split1_) + assert split1.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + assert split1_.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + + split2.save(tmp_path / "split2.slp") + split2_ = load_slp(tmp_path / "split2.slp") + assert len(split2) == len(split2_) + assert split2.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + assert split2_.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + + # Serialization round trip with embedded data + labels = load_slp(slp_real_data) + labels.save(tmp_path / "test.pkg.slp", embed=True) + pkg = load_slp(tmp_path / "test.pkg.slp") + + split1, split2 = pkg.split(n=0.8) + assert len(split1) == 8 + assert len(split2) == 2 + assert split1.video.filename == (tmp_path / "test.pkg.slp").as_posix() + assert split2.video.filename == (tmp_path / "test.pkg.slp").as_posix() + assert ( + split1.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert ( + split2.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + + split1.save(tmp_path / "split1.pkg.slp", embed=True) + split2.save(tmp_path / "split2.pkg.slp", embed=True) + assert pkg.video.filename == (tmp_path / "test.pkg.slp").as_posix() + assert ( + Path(split1.video.filename).as_posix() + == (tmp_path / "split1.pkg.slp").as_posix() + ) + assert ( + Path(split2.video.filename).as_posix() + == (tmp_path / "split2.pkg.slp").as_posix() + ) + assert ( + split1.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert ( + split2.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + + split1_ = load_slp(tmp_path / "split1.pkg.slp") + split2_ = load_slp(tmp_path / "split2.pkg.slp") + assert len(split1_) == 8 + assert len(split2_) == 2 + assert ( + Path(split1_.video.filename).as_posix() + == (tmp_path / "split1.pkg.slp").as_posix() + ) + assert ( + Path(split2_.video.filename).as_posix() + == (tmp_path / "split2.pkg.slp").as_posix() + ) + assert ( + split1_.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert ( + split2_.video.source_video.filename + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + + +def test_make_training_splits(slp_real_data, tmp_path): + labels = load_slp(slp_real_data) + assert len(labels.user_labeled_frames) == 5 + + train, val = labels.make_training_splits(0.8) + assert len(train) == 4 + assert len(val) == 1 + + train, val = labels.make_training_splits(3) + assert len(train) == 3 + assert len(val) == 2 + + train, val = labels.make_training_splits(0.8, 0.2) + assert len(train) == 4 + assert len(val) == 1 + + train, val, test = labels.make_training_splits(0.8, 0.1, 0.1) + assert len(train) == 4 + assert len(val) == 1 + assert len(test) == 1 + + train, val, test = labels.make_training_splits(n_train=0.6, n_test=1) + assert len(train) == 3 + assert len(val) == 1 + assert len(test) == 1 + + train, val, test = labels.make_training_splits(n_train=1, n_val=1, n_test=1) + assert len(train) == 1 + assert len(val) == 1 + assert len(test) == 1 + + +def test_make_training_splits_save(slp_real_data, tmp_path): + labels = load_slp(slp_real_data) + + train, val, test = labels.make_training_splits(0.6, 0.2, 0.2, save_dir=tmp_path) + + train_, val_, test_ = ( + load_slp(tmp_path / "train.pkg.slp"), + load_slp(tmp_path / "val.pkg.slp"), + load_slp(tmp_path / "test.pkg.slp"), + ) + + assert len(train_) == len(train) + assert len(val_) == len(val) + assert len(test_) == len(test) + + assert train_.provenance["source_labels"] == slp_real_data + assert val_.provenance["source_labels"] == slp_real_data + assert test_.provenance["source_labels"] == slp_real_data