-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for the Kumar Lab's JABS format #63
Conversation
Hi @SkepticRaven! Thanks for the initial PR -- this is looking great!! As far as static objects and other attributes (px_to_cm): These are useful additions and I believe that we could add them to the core data models. For the sake of keeping this PR manageable, I propose that you go ahead and finish implementing the basic support for JABS formatted files and we'll do a second PR that adds new data model features. The other formats that don't support those can just ignore them. |
Codecov Report
@@ Coverage Diff @@
## main #63 +/- ##
==========================================
- Coverage 94.68% 94.45% -0.23%
==========================================
Files 12 13 +1
Lines 1241 1442 +201
==========================================
+ Hits 1175 1362 +187
- Misses 66 80 +14
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
Hey @talmo. I agree that attributes (eg px_to_cm) don't yet have a place -- can leave it out of this PR and add it when SLEAP has a place for it! As for our training data, I think I'll leave that out of this pull request. Our single mouse training data is in a h5 format (originating from CVAT). Our multi-mouse data has been converted between multiple formats (including SLEAP and Label-Studio, both already supported here) and we haven't released it yet. |
…all Instances have skeletons)
WalkthroughThis pull request introduces support for JABS pose file format in the SLEAP codebase. It adds new functions for reading and writing JABS files, converting labels to JABS format, and generating default skeletons and symmetries. The changes also include test fixtures and test cases to validate the functionality of these additions. Changes
TipsChat with CodeRabbit Bot (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review Status
Actionable comments generated: 9
Files selected for processing (7)
- sleap_io/init.py (1 hunks)
- sleap_io/io/jabs.py (1 hunks)
- sleap_io/io/main.py (2 hunks)
- tests/conftest.py (1 hunks)
- tests/fixtures/jabs.py (1 hunks)
- tests/io/test_jabs.py (1 hunks)
- tests/io/test_main.py (2 hunks)
Files skipped from review due to trivial changes (3)
- sleap_io/init.py
- tests/conftest.py
- tests/io/test_jabs.py
Additional comments (Suppressed): 4
tests/fixtures/jabs.py (1)
- 1-30: The fixtures
jabs_real_data_v2
andjabs_real_data_v5
are well defined and provide clear documentation about the data they return. The paths returned by these fixtures seem to be hardcoded. Ensure that these files exist in the repository and are placed in the correct directory.sleap_io/io/main.py (2)
79-105: The new functions
load_jabs
andsave_jabs
have been added to handle JABS file format. The function signatures and documentation look good. However, please ensure that theskeleton
parameter inload_jabs
andpose_version
parameter insave_jabs
are handled correctly in all calls to these functions throughout the codebase.4-5: The import statement has been updated to include the
jabs
module fromsleap_io.io
. This is necessary for the newly added JABS support.tests/io/test_main.py (1)
- 6-11: The import statements have been updated to include the new
load_jabs
andsave_jabs
functions. Ensure that these functions are correctly implemented in the imported module.
tests/io/test_main.py
Outdated
def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): | ||
labels_single = load_jabs(jabs_real_data_v2) | ||
assert isinstance(labels_single, Labels) | ||
save_jabs(labels_single, 2, tmp_path) | ||
|
||
labels_multi = load_jabs(jabs_real_data_v5) | ||
assert isinstance(labels_multi, Labels) | ||
save_jabs(labels_multi, 3, tmp_path) | ||
save_jabs(labels_multi, 4, tmp_path) | ||
save_jabs(labels_multi, 5, tmp_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a new test function for the JABS file format support. It tests both loading and saving of JABS files. However, there are no assertions to verify the correctness of the saved files. Consider adding checks to ensure that the saved files match the expected output.
save_jabs(labels_single, 2, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "labels_single.jabs")
+ assert labels_single == loaded_labels
save_jabs(labels_multi, 3, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "labels_multi_v3.jabs")
+ assert labels_multi == loaded_labels
save_jabs(labels_multi, 4, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "labels_multi_v4.jabs")
+ assert labels_multi == loaded_labels
save_jabs(labels_multi, 5, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "labels_multi_v5.jabs")
+ assert labels_multi == loaded_labels
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v2 check could work, but the remainder do not guarantee sorting of data. On first load/save, the data is equivalent, but not identical. A second round of loading + saving might allow this check to work.
v3 and v4 checks definitely won't work, since data is intentionally discarded.
Also, the filenames are generated based on videos in the annotation and are not labels_*.jabs
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SkepticRaven, thank you for your feedback. I understand the challenges with the v3 and v4 checks due to data being intentionally discarded. In this case, we could modify our tests to only check for the presence of essential data rather than exact equality.
As for the filenames, I apologize for the oversight. We should indeed use the actual filenames generated based on the videos in the annotation. Here's a revised version of the test:
def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
labels_single = load_jabs(jabs_real_data_v2)
assert isinstance(labels_single, Labels)
save_jabs(labels_single, 2, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video.jabs")
+ assert labels_single.essential_data == loaded_labels.essential_data
labels_multi = load_jabs(jabs_real_data_v5)
assert isinstance(labels_multi, Labels)
save_jabs(labels_multi, 3, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v3.jabs")
+ assert labels_multi.essential_data == loaded_labels.essential_data
save_jabs(labels_multi, 4, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v4.jabs")
+ assert labels_multi.essential_data == loaded_labels.essential_data
save_jabs(labels_multi, 5, tmp_path)
+ loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v5.jabs")
+ assert labels_multi.essential_data == loaded_labels.essential_data
Please replace actual_filename_based_on_video
with the correct filename based on the video in the annotation. The essential_data
attribute should be replaced with the actual attribute or method that checks for the presence of essential data in the Labels
object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As an alternative, can we just minimally test for the len(labels)
being equivalent before and after the roundtrip?
Full equality testing will be a bit onerous currently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like humans are talking. I'll be quiet for now. If you need me again, just tag @coderabbitai in a new comment.
sleap_io/io/jabs.py
Outdated
frames: List[LabeledFrame] = [] | ||
# Video name is the pose file minus the suffix | ||
video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path) | ||
if not skeleton: | ||
skeleton = JABS_DEFAULT_SKELETON | ||
tracks = {} | ||
|
||
with h5py.File(labels_path, "r") as pose_file: | ||
num_frames = pose_file["poseest/points"].shape[0] | ||
try: | ||
pose_version = pose_file["poseest"].attrs["version"][0] | ||
except Exception: | ||
pose_version = 2 | ||
tracks[1] = Track("1") | ||
data_shape = pose_file["poseest/points"].shape | ||
assert ( | ||
len(data_shape) == 3 | ||
), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}" | ||
# Change field name for newer pose formats | ||
if pose_version == 3: | ||
id_key = "instance_track_id" | ||
elif pose_version > 3: | ||
id_key = "instance_embed_id" | ||
max_ids = pose_file["poseest/points"].shape[1] | ||
|
||
for frame_idx in range(num_frames): | ||
instances = [] | ||
pose_data = pose_file["poseest/points"][frame_idx, ...] | ||
# JABS stores y,x for poses | ||
pose_data = np.flip(pose_data, axis=-1) | ||
pose_conf = pose_file["poseest/confidence"][frame_idx, ...] | ||
# single animal case | ||
if pose_version == 2: | ||
new_instance = prediction_to_instance( | ||
pose_data, pose_conf, skeleton, tracks[1] | ||
) | ||
instances.append(new_instance) | ||
# multi-animal case | ||
if pose_version > 2: | ||
pose_ids = pose_file["poseest/" + id_key][frame_idx, ...] | ||
# pose_v3 uses another field to describe the number of valid poses | ||
if pose_version == 3: | ||
max_ids = pose_file["poseest/instance_count"][frame_idx] | ||
for cur_id in range(max_ids): | ||
# v4+ uses reserved values for invalid/unused poses | ||
# Note: ignores 'poseest/id_mask' to keep predictions that were not assigned an id | ||
if pose_version > 3 and pose_ids[cur_id] <= 0: | ||
continue | ||
if pose_ids[cur_id] not in tracks.keys(): | ||
tracks[pose_ids[cur_id]] = Track(str(pose_ids[cur_id])) | ||
new_instance = prediction_to_instance( | ||
pose_data[cur_id], | ||
pose_conf[cur_id], | ||
skeleton, | ||
tracks[pose_ids[cur_id]], | ||
) | ||
if new_instance: | ||
instances.append(new_instance) | ||
# Static objects | ||
if ( | ||
frame_idx == 0 | ||
and pose_version >= 5 | ||
and "static_objects" in pose_file.keys() | ||
): | ||
present_objects = pose_file["static_objects"].keys() | ||
for cur_object in present_objects: | ||
object_keypoints = pose_file["static_objects/" + cur_object][:] | ||
object_skeleton = make_simple_skeleton( | ||
cur_object, object_keypoints.shape[0] | ||
) | ||
new_instance = prediction_to_instance( | ||
object_keypoints, | ||
np.ones(object_keypoints.shape[:-1]), | ||
object_skeleton, | ||
) | ||
if new_instance: | ||
instances.append(new_instance) | ||
frame_label = LabeledFrame(Video(video_name), frame_idx, instances) | ||
frames.append(frame_label) | ||
return Labels(frames) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is quite long and does multiple things: it reads data from a file, processes it, and creates various objects. Consider breaking it down into smaller functions for better readability and maintainability. For example, you could have separate functions for reading the file, processing the data, and creating the objects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@coderabbitai: ignore
"creating the objects" is already separated as prediction_to_instance
.
Reading and processing the data is coupled to reduce RAM footprint (only have 2 copies in the frame loop instead of video loop).
The checks and keys for reading could move to another function, but I don't have the time to make that refactor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SkepticRaven: I understand your concerns about memory usage and time constraints. However, it's important to note that maintainability and readability of the code are also crucial aspects of software development.
While the prediction_to_instance
function does indeed separate the creation of objects, the read_labels
function still performs multiple tasks which could be further broken down. For instance, handling different pose versions and processing static objects could potentially be moved into their own functions.
Here's a quick example of how you might refactor the pose version handling:
def handle_pose_version(pose_file, num_frames):
try:
pose_version = pose_file["poseest"].attrs["version"][0]
except Exception:
pose_version = 2
tracks[1] = Track("1")
data_shape = pose_file["poseest/points"].shape
assert (
len(data_shape) == 3
), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
return pose_version
# In read_labels function
pose_version = handle_pose_version(pose_file, num_frames)
This is just a suggestion and may not perfectly fit your needs, but hopefully it gives you an idea of how you might approach this. Even small steps towards cleaner, more modular code can have significant benefits in the long run.
sleap_io/io/jabs.py
Outdated
pose_version = 2 | ||
tracks[1] = Track("1") | ||
data_shape = pose_file["poseest/points"].shape | ||
assert ( | ||
len(data_shape) == 3 | ||
), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no specific error handling here. If an exception occurs when trying to read the pose version, the code assumes that the pose version is 2. However, the exception could be due to other issues (e.g., file not found, permission denied). It would be better to catch specific exceptions and handle them appropriately.
sleap_io/io/jabs.py
Outdated
if confidence[i] > 0.001: | ||
points[cur_node] = Point( | ||
data[i, 0], | ||
data[i, 1], | ||
visible=True, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The confidence threshold (0.001) is hardcoded. Consider making this a parameter of the function so that it can be easily adjusted if needed.
sleap_io/io/jabs.py
Outdated
if os.path.exists(out_filename): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code checks if the output file already exists but doesn't do anything with this information. If overwriting is not desired, consider raising an error or warning if the file already exists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM! I'll leave it open for another day or until I hear back in case you want to make more changes, then we'll merge.
Thanks!!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review Status
Actionable comments generated: 6
Files selected for processing (2)
- sleap_io/io/jabs.py (1 hunks)
- tests/io/test_main.py (2 hunks)
Additional comments (Suppressed): 1
tests/io/test_main.py (1)
- 6-11: The import statements have been updated to include the new
load_jabs
andsave_jabs
functions. Ensure that these functions are correctly implemented in the imported module.
sleap_io/io/jabs.py
Outdated
if not os.access(labels_path, os.F_OK): | ||
raise PermissionError(f"{labels_path} cannot be accessed.") | ||
if not os.access(labels_path, os.R_OK): | ||
raise FileNotFoundError(f"{labels_path} doesn't exist.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The order of the checks for file existence and accessibility is incorrect. The check for file existence should be performed before checking for file access permissions. If the file does not exist, there's no point in checking its permissions.
- if not os.access(labels_path, os.F_OK):
- raise PermissionError(f"{labels_path} cannot be accessed.")
- if not os.access(labels_path, os.R_OK):
- raise FileNotFoundError(f"{labels_path} doesn't exist.")
+ if not os.path.exists(labels_path):
+ raise FileNotFoundError(f"{labels_path} doesn't exist.")
+ if not os.access(labels_path, os.R_OK):
+ raise PermissionError(f"{labels_path} cannot be accessed.")
assert ( | ||
len(skeleton.nodes) == data.shape[0] | ||
), f"Skeleton ({len(skeleton.nodes)}) does not match number of keypoints ({data.shape[0]})" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion error message seems to be swapped. It should indicate that the number of keypoints does not match the skeleton nodes.
- assert (
- len(skeleton.nodes) == data.shape[0]
- ), f"Skeleton ({len(skeleton.nodes)}) does not match number of keypoints ({data.shape[0]})"
+ assert (
+ len(skeleton.nodes) == data.shape[0]
+ ), f"Number of keypoints ({data.shape[0]}) does not match skeleton nodes ({len(skeleton.nodes)})"
warnings.warn( | ||
f"JABS format only supports 12 keypoints for mice. Skipping storage of instance on frame {label.frame_idx} with {len(instance.points)} keypoints." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This warning message could potentially flood the console if there are many instances with a different number of keypoints than expected. Consider aggregating these warnings and printing a summary instead of individual messages for each instance.
for video in labels.videos: | ||
converted_labels = convert_labels(labels, video) | ||
out_filename = ( | ||
os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5" | ||
) | ||
if root_folder: | ||
out_filename = os.path.join(root_folder, out_filename) | ||
os.makedirs(os.path.dirname(out_filename), exist_ok=True) | ||
if os.path.exists(out_filename): | ||
warnings.warn(f"Skipping {out_filename} because it already exists.") | ||
continue | ||
if pose_version == 2: | ||
write_jabs_v2(converted_labels, out_filename) | ||
elif pose_version == 3: | ||
write_jabs_v3(converted_labels, out_filename) | ||
elif pose_version == 4: | ||
write_jabs_v4(converted_labels, out_filename) | ||
elif pose_version == 5: | ||
write_jabs_v5(converted_labels, out_filename) | ||
else: | ||
raise NotImplementedError(f"Pose format {pose_version} not supported.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The write_labels
function currently supports pose versions 2 to 5. However, it does not handle cases where an unsupported version is passed. Adding a default case to handle unsupported versions would improve error handling.
+ else:
+ raise ValueError(f"Unsupported pose version: {pose_version}")
def write_jabs_v5(data: dict, filename: str): | ||
"""Write JABS pose file v5 data to file. | ||
|
||
Writes multi-mouse pose, longterm identity, and static object data. | ||
|
||
Args: | ||
data: Dictionary of JABS data generated from convert_labels | ||
filename: Filename to write data to | ||
""" | ||
# v5 extends v4 | ||
write_jabs_v4(data, filename) | ||
with h5py.File(filename, "a") as h5: | ||
pose_grp = h5.require_group("poseest") | ||
pose_grp.attrs.update({"version": [5, 0]}) | ||
if "static_objects" in data.keys(): | ||
object_grp = h5.require_group("static_objects") | ||
for object_key, object_keypoints in data["static_objects"].items(): | ||
object_grp.require_dataset( | ||
object_key, | ||
object_keypoints.shape, | ||
np.uint16, | ||
data=object_keypoints.astype(np.uint16), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The write_jabs_v2
, write_jabs_v3
, write_jabs_v4
, and write_jabs_v5
functions are very similar and contain duplicated code. Consider refactoring these functions to reduce code duplication and improve maintainability.
def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): | ||
labels_single = load_jabs(jabs_real_data_v2) | ||
assert isinstance(labels_single, Labels) | ||
save_jabs(labels_single, 2, tmp_path) | ||
labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2)) | ||
# Confidence field is not preserved, so just check number of labels | ||
assert len(labels_single) == len(labels_single_written) | ||
|
||
labels_multi = load_jabs(jabs_real_data_v5) | ||
assert isinstance(labels_multi, Labels) | ||
save_jabs(labels_multi, 3, tmp_path) | ||
save_jabs(labels_multi, 4, tmp_path) | ||
save_jabs(labels_multi, 5, tmp_path) | ||
labels_v5_written = load_jabs(str(tmp_path / jabs_real_data_v5)) | ||
# v5 contains all v4 and v3 data, so only need to check v5 | ||
# Confidence field and ordering of identities is not preserved, so just check number of labels | ||
assert len(labels_v5_written) == len(labels_multi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A new test function, test_jabs
, has been added to test the functionality of the load_jabs
and save_jabs
functions. The tests seem to be well-structured and cover different scenarios (single and multi-mouse poses). However, there is a potential issue with the way files are being saved and loaded for testing.
In lines 46 and 53-55, the save_jabs
function is called to save labels to a file at tmp_path
. However, the filename argument passed to save_jabs
is an integer (2, 3, 4, or 5), not a string representing a filename. This could lead to unexpected behavior or errors when trying to save the file.
In lines 47 and 56, the load_jabs
function is called to load labels from a file at str(tmp_path / jabs_real_data_v2)
or str(tmp_path / jabs_real_data_v5)
. However, these paths may not correspond to the files saved by save_jabs
if the filenames are not correctly specified.
Please verify that the save_jabs
function can accept an integer as a filename, and that the paths used in load_jabs
correspond to the files saved by save_jabs
.
- save_jabs(labels_single, 2, tmp_path)
+ save_jabs(labels_single, "file2", tmp_path)
- save_jabs(labels_multi, 3, tmp_path)
+ save_jabs(labels_multi, "file3", tmp_path)
- save_jabs(labels_multi, 4, tmp_path)
+ save_jabs(labels_multi, "file4", tmp_path)
- save_jabs(labels_multi, 5, tmp_path)
+ save_jabs(labels_multi, "file5", tmp_path)
- labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2))
+ labels_single_written = load_jabs(str(tmp_path / "file2"))
- labels_v5_written = load_jabs(str(tmp_path / jabs_real_data_v5))
+ labels_v5_written = load_jabs(str(tmp_path / "file5"))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Review Status
Actionable comments generated: 3
Files selected for processing (1)
- sleap_io/io/jabs.py (1 hunks)
Additional comments (Suppressed): 10
sleap_io/io/jabs.py (10)
1-66: The module starts by defining the default keypoints, edges, symmetries, and skeleton for JABS files. This is a good practice as it provides a clear understanding of the expected structure of the data and allows for easy modification if the structure changes in the future.
91-94: Good error handling for file access issues. The function checks if the file exists and if it can be accessed before attempting to read it.
172-184: The
make_simple_skeleton
function creates aSkeleton
with a requested number of nodes attached in a line. This could be useful for creating custom skeletons for different types of data.187-220: The
prediction_to_instance
function converts prediction data into anInstance
. It includes a check to ensure that the number of keypoints in the skeleton matches the number of keypoints in the data, which is a good practice for ensuring data consistency.222-237: The
get_max_ids_in_video
function determines the maximum number of identities that exist at the same time in a video. This could be useful for allocating resources or setting parameters based on the complexity of the video.350-386: The
tracklets_to_v3
function changes identity tracklets to the v3 format specifications. It includes a check to ensure that the input tracklet matrix is two-dimensional, which is a good practice for ensuring data consistency.389-415: The
write_jabs_v2
function writes JABS pose file v2 data to file. It checks that the data being written is single mouse data, which is a good practice for ensuring data consistency.417-466: The
write_jabs_v3
function writes JABS pose file v3 data to file. It uses thetracklets_to_v3
function to convert the identity tracklets to the v3 format before writing them to file.467-514: The
write_jabs_v4
function writes JABS pose file v4 data to file. It extends thewrite_jabs_v3
function and adds new fields for v4.516-538: The
write_jabs_v5
function writes JABS pose file v5 data to file. It extends thewrite_jabs_v4
function and adds support for static objects.
"""Read JABS style pose from a file and return a `Labels` object. | ||
|
||
TODO: Attributes are ignored, including px_to_cm field. | ||
TODO: Segmentation data ignored in v6, but will read in pose. | ||
TODO: Lixit static objects currently stored as n_lixit,2 (eg 1 object). Should be converted to multiple objects | ||
|
||
Args: | ||
labels_path: Path to the JABS pose file. | ||
skeleton: An optional `Skeleton` object. Defaults to JABS pose version 2-6. | ||
|
||
Returns: | ||
Parsed labels as a `Labels` instance. | ||
""" | ||
frames: List[LabeledFrame] = [] | ||
# Video name is the pose file minus the suffix | ||
video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path) | ||
if not skeleton: | ||
skeleton = JABS_DEFAULT_SKELETON | ||
tracks = {} | ||
|
||
if not os.access(labels_path, os.F_OK): | ||
raise FileNotFoundError(f"{labels_path} doesn't exist.") | ||
if not os.access(labels_path, os.R_OK): | ||
raise PermissionError(f"{labels_path} cannot be accessed.") | ||
|
||
with h5py.File(labels_path, "r") as pose_file: | ||
num_frames = pose_file["poseest/points"].shape[0] | ||
try: | ||
pose_version = pose_file["poseest"].attrs["version"][0] | ||
except (KeyError, IndexError): | ||
pose_version = 2 | ||
data_shape = pose_file["poseest/points"].shape | ||
assert ( | ||
len(data_shape) == 3 | ||
), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}" | ||
if pose_version == 2: | ||
tracks[1] = Track("1") | ||
# Change field name for newer pose formats | ||
if pose_version == 3: | ||
id_key = "instance_track_id" | ||
elif pose_version > 3: | ||
id_key = "instance_embed_id" | ||
max_ids = pose_file["poseest/points"].shape[1] | ||
|
||
for frame_idx in range(num_frames): | ||
instances = [] | ||
pose_data = pose_file["poseest/points"][frame_idx, ...] | ||
# JABS stores y,x for poses | ||
pose_data = np.flip(pose_data, axis=-1) | ||
pose_conf = pose_file["poseest/confidence"][frame_idx, ...] | ||
# single animal case | ||
if pose_version == 2: | ||
new_instance = prediction_to_instance( | ||
pose_data, pose_conf, skeleton, tracks[1] | ||
) | ||
instances.append(new_instance) | ||
# multi-animal case | ||
if pose_version > 2: | ||
pose_ids = pose_file["poseest/" + id_key][frame_idx, ...] | ||
# pose_v3 uses another field to describe the number of valid poses | ||
if pose_version == 3: | ||
max_ids = pose_file["poseest/instance_count"][frame_idx] | ||
for cur_id in range(max_ids): | ||
# v4+ uses reserved values for invalid/unused poses | ||
# Note: ignores 'poseest/id_mask' to keep predictions that were not assigned an id | ||
if pose_version > 3 and pose_ids[cur_id] <= 0: | ||
continue | ||
if pose_ids[cur_id] not in tracks.keys(): | ||
tracks[pose_ids[cur_id]] = Track(str(pose_ids[cur_id])) | ||
new_instance = prediction_to_instance( | ||
pose_data[cur_id], | ||
pose_conf[cur_id], | ||
skeleton, | ||
tracks[pose_ids[cur_id]], | ||
) | ||
if new_instance: | ||
instances.append(new_instance) | ||
# Static objects | ||
if ( | ||
frame_idx == 0 | ||
and pose_version >= 5 | ||
and "static_objects" in pose_file.keys() | ||
): | ||
present_objects = pose_file["static_objects"].keys() | ||
for cur_object in present_objects: | ||
object_keypoints = pose_file["static_objects/" + cur_object][:] | ||
object_skeleton = make_simple_skeleton( | ||
cur_object, object_keypoints.shape[0] | ||
) | ||
new_instance = prediction_to_instance( | ||
object_keypoints, | ||
np.ones(object_keypoints.shape[:-1]), | ||
object_skeleton, | ||
) | ||
if new_instance: | ||
instances.append(new_instance) | ||
frame_label = LabeledFrame(Video(video_name), frame_idx, instances) | ||
frames.append(frame_label) | ||
return Labels(frames) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The read_labels
function reads JABS style pose from a file and returns a Labels
object. It handles both single animal and multi-animal cases, as well as static objects. However, there are several TODO comments indicating areas for future improvement, such as handling attributes, segmentation data, and lixit static objects. These should be addressed to ensure the function fully supports all aspects of the JABS format.
def convert_labels(all_labels: Labels, video: Video) -> dict: | ||
"""Convert a `Labels` object into JABS-formatted annotations. | ||
|
||
Args: | ||
all_labels: SLEAP `Labels` to be converted to JABS format. | ||
video: name of video to be converted | ||
|
||
Returns: | ||
Dictionary of JABS data of the `Labels` data. | ||
""" | ||
labels = all_labels.find(video=video) | ||
|
||
# Determine shape of output | ||
# Low estimate of last frame labeled | ||
num_frames = max([x.frame_idx for x in labels]) + 1 | ||
# If there is metadata available for the video, use that | ||
if video.shape: | ||
num_frames = max(num_frames, video.shape[0]) | ||
num_keypoints = [len(x.nodes) for x in all_labels.skeletons if x.name == "Mouse"][0] | ||
num_mice = get_max_ids_in_video(labels, key="Mouse") | ||
# Note that this 1-indexes identities | ||
track_2_idx = { | ||
key: val + 1 | ||
for key, val in zip(all_labels.tracks, range(len(all_labels.tracks))) | ||
} | ||
last_unassigned_id = num_mice | ||
|
||
keypoint_mat = np.zeros([num_frames, num_mice, num_keypoints, 2], dtype=np.uint16) | ||
confidence_mat = np.zeros([num_frames, num_mice, num_keypoints], dtype=np.float32) | ||
identity_mat = np.zeros([num_frames, num_mice], dtype=np.uint32) | ||
instance_vector = np.zeros([num_frames], dtype=np.uint8) | ||
static_objects = {} | ||
|
||
# Populate the matrices with data | ||
for label in labels: | ||
assigned_instances = 0 | ||
for instance_idx, instance in enumerate(label.instances): | ||
# Static objects just get added to the object dict | ||
# This will clobber data if more than one frame is annotated | ||
if instance.skeleton.name != "Mouse": | ||
static_objects[instance.skeleton.name] = instance.numpy() | ||
continue | ||
pose = instance.numpy() | ||
if pose.shape[0] != len(JABS_DEFAULT_KEYPOINTS): | ||
warnings.warn( | ||
f"JABS format only supports 12 keypoints for mice. Skipping storage of instance on frame {label.frame_idx} with {len(instance.points)} keypoints." | ||
) | ||
continue | ||
missing_points = np.isnan(pose[:, 0]) | ||
pose[np.isnan(pose)] = 0 | ||
# JABS stores y,x for poses | ||
pose = np.flip(pose.astype(np.uint16), axis=-1) | ||
keypoint_mat[label.frame_idx, instance_idx, :, :] = pose | ||
confidence_mat[label.frame_idx, instance_idx, ~missing_points] = 1.0 | ||
if instance.track: | ||
identity_mat[label.frame_idx, instance_idx] = track_2_idx[ | ||
instance.track | ||
] | ||
else: | ||
warnings.warn( | ||
f"Pose with unassigned track found on {label.video.filename} frame {label.frame_idx} instance {instance_idx}. Assigning ID {last_unassigned_id}." | ||
) | ||
identity_mat[label.frame_idx, instance_idx] = last_unassigned_id | ||
last_unassigned_id += 1 | ||
assigned_instances += 1 | ||
instance_vector[label.frame_idx] = assigned_instances | ||
|
||
# Return the data as a dict | ||
return { | ||
"keypoints": keypoint_mat.astype(np.uint16), | ||
"confidence": confidence_mat.astype(np.float32), | ||
"identity": identity_mat.astype(np.uint32), | ||
"num_identities": instance_vector.astype(np.uint16), | ||
"static_objects": static_objects, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The convert_labels
function converts a Labels
object into JABS-formatted annotations. It handles both single mouse and multi-mouse cases, as well as static objects. However, there is a warning message indicating that JABS format only supports 12 keypoints for mice and instances with more keypoints will be skipped. This could potentially lead to data loss if the input Labels
object contains instances with more than 12 keypoints.
def write_labels(labels: Labels, pose_version: int, root_folder: str): | ||
"""Convert and save a SLEAP `Labels` object to a JABS pose file. | ||
|
||
Only supports pose version 2 (single mouse) and 3-5 (multi mouse). | ||
|
||
Args: | ||
labels: SLEAP `Labels` to be converted to JABS pose format. | ||
pose_version: JABS pose version to use when writing data. | ||
root_folder: Root folder where the jabs files should be written | ||
""" | ||
for video in labels.videos: | ||
converted_labels = convert_labels(labels, video) | ||
out_filename = ( | ||
os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5" | ||
) | ||
if root_folder: | ||
out_filename = os.path.join(root_folder, out_filename) | ||
os.makedirs(os.path.dirname(out_filename), exist_ok=True) | ||
if os.path.exists(out_filename): | ||
warnings.warn(f"Skipping {out_filename} because it already exists.") | ||
continue | ||
if pose_version == 2: | ||
write_jabs_v2(converted_labels, out_filename) | ||
elif pose_version == 3: | ||
write_jabs_v3(converted_labels, out_filename) | ||
elif pose_version == 4: | ||
write_jabs_v4(converted_labels, out_filename) | ||
elif pose_version == 5: | ||
write_jabs_v5(converted_labels, out_filename) | ||
else: | ||
raise NotImplementedError(f"Pose format {pose_version} not supported.") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The write_labels
function converts and saves a SLEAP Labels
object to a JABS pose file. It supports multiple versions of the JABS pose format, which is a good practice for ensuring compatibility with different versions of the format. However, it raises a NotImplementedError
if an unsupported pose version is provided. It would be better to provide a more informative error message indicating which versions are supported.
- raise NotImplementedError(f"Pose format {pose_version} not supported.")
+ raise ValueError(f"Unsupported pose format {pose_version}. Supported formats are 2, 3, 4, and 5.")
@talmo I think that should be it! |
Adds conversion of JABS pose files into SLEAP.
List of things supported:
a. 1 pose file per video (as per JABS format)
TODO:
Add support for our training data?-- Deferring to future PRStatic objects not handleda.
Duplicate across all frames or just carry over to first?b.
Writing out assumes all data is same without looking at skeletona.
Track support requires integer names (convert to factorize)b.
Shapes of data may not be guaranteed (eg frame matrix should match shape of video, not annotations available)c.
Attribute data not preserved... is there a way to preserve it? px_to_cm is important-- Deferring to future PRd.
Tests/example data?e.
Enforce 12 keypoint names such that data is compatible with JABS related tools (eg https://github.com/KumarLabJax/JABS-behavior-classifier and https://github.com/KumarLabJax/gaitanalysis)Run blackv3 writer needs to break apart non-continuous tracksSummary by CodeRabbit
load_jabs
andsave_jabs
functions.