Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import pose #1225

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
- Set `probe_id` as `probe_description` when inserting from nwb file #1220
- Default `AnalysisNwbfile.create` permissions are now 777 #1226
- Position
- Allow population of missing `PositionIntervalMap` entries during population of `DLCPoseEstimation` #1208
- Allow population of missing `PositionIntervalMap` entries during population
of `DLCPoseEstimation` #1208
- Enable import of existing pose data to `ImportedPose` in position pipeline #1225

## [0.5.4] (December 20, 2024)

Expand Down
8 changes: 8 additions & 0 deletions docs/src/ForDevelopers/UsingNWB.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ hdmf.common.table.DynamicTable </b>
| PositionSource.SpatialSeries | id | int(nwbf.processing.behavior.position.\[index\]) (the enumerated index number) | | |
| RawPosition.PosObject | raw_position_object_id | nwbf.processing.behavior.position.\[index\].object_id | | |

<b> NWBfile Location: nwbf.processing.behavior.PoseEstimation </br> Object type:
(ndx_pose.PoseEstimation) </b>

| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :--------------------------- | :--------------------: | -------------------------------------------------------------------------------: | ------------: | --------------------: |
| ImportedPose | interval_list_name | pose_{PoseEstimation.name}_valid_intervals | | |
| ImportedPose.BodyPart | part_name | nwbf.processing.behavior.PoseEstimation.pose_estimation_series.name | | |

<b> NWBfile Location: nwbf.processing.video_files.video </br> Object type:
pynwb.image.ImageSeries </b>

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ authors = [
{ name = "Ryan Ly", email = "rly@lbl.gov" },
{ name = "Daniel Gramling", email = "daniel.gramling@ucsf.edu" },
{ name = "Chris Brozdowski", email = "chris.broz@ucsf.edu" },
{ name = "Samuel Bray", email = "sam.bray@ucsf.edu" },
]
classifiers = [
"Programming Language :: Python :: 3",
Expand Down Expand Up @@ -45,6 +46,7 @@ dependencies = [
"ipympl",
"matplotlib",
"ndx_franklab_novela>=0.1.0",
"ndx-pose",
"non_local_detector",
"numpy",
"opencv-python",
Expand Down
4 changes: 4 additions & 0 deletions src/spyglass/common/populate_all_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def single_transaction_make(
if table.__name__ == "PositionSource":
# PositionSource only uses nwb_file_name - full calls redundant
key_source = dj.U("nwb_file_name") & key_source
if table.__name__ == "ImportedPose":
key_source = Nwbfile()

for pop_key in (key_source & file_restr).fetch("KEY"):
try:
Expand Down Expand Up @@ -116,6 +118,7 @@ def populate_all_common(
List
A list of keys for InsertError entries if any errors occurred.
"""
from spyglass.position.v1.imported_pose import ImportedPose
from spyglass.spikesorting.imported import ImportedSpikeSorting

declare_all_merge_tables()
Expand Down Expand Up @@ -143,6 +146,7 @@ def populate_all_common(
PositionSource, # Depends on Session
VideoFile, # Depends on TaskEpoch
StateScriptFile, # Depends on TaskEpoch
ImportedPose, # Depends on Session
],
[
RawPosition, # Depends on PositionSource
Expand Down
13 changes: 13 additions & 0 deletions src/spyglass/position/position_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pandas import DataFrame

from spyglass.common.common_position import IntervalPositionInfo as CommonPos
from spyglass.position.v1.imported_pose import ImportedPose
from spyglass.position.v1.position_dlc_selection import DLCPosV1
from spyglass.position.v1.position_trodes_position import TrodesPosV1
from spyglass.utils import SpyglassMixin, _Merge
Expand All @@ -13,6 +14,7 @@
"IntervalPositionInfo": CommonPos,
"DLCPosV1": DLCPosV1,
"TrodesPosV1": TrodesPosV1,
"ImportedPose": ImportedPose,
}


Expand Down Expand Up @@ -63,6 +65,17 @@ class CommonPos(SpyglassMixin, dj.Part):
-> CommonPos
"""

class ImportedPose(SpyglassMixin, dj.Part):
    """Part table to pass-through upstream Pose information from NWB file.

    Links a PositionOutput merge entry to its source ImportedPose entry.
    """

    definition = """
    -> PositionOutput
    ---
    -> ImportedPose
    """

def fetch1_dataframe(self) -> DataFrame:
"""Fetch a single dataframe from the merged table."""
# proj replaces operator restriction to enable
Expand Down
152 changes: 152 additions & 0 deletions src/spyglass/position/v1/imported_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import datajoint as dj
import ndx_pose
import numpy as np
import pandas as pd
import pynwb

from spyglass.common import IntervalList, Nwbfile
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils.nwb_helper_fn import (
estimate_sampling_rate,
get_valid_intervals,
)

schema = dj.schema("position_v1_imported_pose")


@schema
class ImportedPose(SpyglassMixin, dj.Manual):
    """Table to ingest pose data generated prior to spyglass.

    Each entry corresponds to one ndx_pose.PoseEstimation object in an NWB
    file. PoseEstimation objects should be stored in nwb.processing.behavior.

    Assumptions:
    - Single skeleton object per PoseEstimation object
    """

    _nwb_table = Nwbfile

    definition = """
    -> IntervalList
    ---
    pose_object_id: varchar(80) # unique identifier for the pose object
    skeleton_object_id: varchar(80) # unique identifier for the skeleton object
    """

    class BodyPart(SpyglassMixin, dj.Part):
        """One entry per tracked body part within a PoseEstimation object."""

        definition = """
        -> master
        part_name: varchar(80)
        ---
        part_object_id: varchar(80)
        """

    def make(self, key):
        """Ingest all PoseEstimation objects from the NWB file in key."""
        self.insert_from_nwbfile(key["nwb_file_name"])

    def insert_from_nwbfile(self, nwb_file_name):
        """Insert every ndx_pose.PoseEstimation object found in an NWB file.

        For each PoseEstimation object in the file's "behavior" processing
        module, creates one IntervalList entry (valid times estimated from
        the first body part's timestamps), one master entry, and one BodyPart
        entry per pose_estimation_series.

        Parameters
        ----------
        nwb_file_name : str
            Name of the NWB file; resolved to a path via Nwbfile.get_abs_path.
        """
        file_path = Nwbfile().get_abs_path(nwb_file_name)
        interval_keys, master_keys, part_keys = [], [], []
        with pynwb.NWBHDF5IO(file_path, mode="r") as io:
            nwb = io.read()
            behavior_module = nwb.get_processing_module("behavior")

            # Loop through all the PoseEstimation objects in the behavior
            # module
            for name, obj in behavior_module.data_interfaces.items():
                if not isinstance(obj, ndx_pose.PoseEstimation):
                    continue

                # use the timestamps from the first body part to define valid
                # times; assumes all body parts share a common clock
                timestamps = list(obj.pose_estimation_series.values())[
                    0
                ].timestamps[:]
                sampling_rate = estimate_sampling_rate(
                    timestamps, filename=nwb_file_name
                )
                valid_intervals = get_valid_intervals(
                    timestamps,
                    sampling_rate=sampling_rate,
                    min_valid_len=sampling_rate,
                )
                interval_key = {
                    "nwb_file_name": nwb_file_name,
                    "interval_list_name": f"pose_{name}_valid_intervals",
                    "valid_times": valid_intervals,
                    "pipeline": "ImportedPose",
                }
                interval_keys.append(interval_key)

                # master key
                master_keys.append(
                    {
                        "nwb_file_name": nwb_file_name,
                        "interval_list_name": interval_key[
                            "interval_list_name"
                        ],
                        "pose_object_id": obj.object_id,
                        "skeleton_object_id": obj.skeleton.object_id,
                    }
                )

                # part keys, one per tracked body part
                for part, part_obj in obj.pose_estimation_series.items():
                    part_keys.append(
                        {
                            "nwb_file_name": nwb_file_name,
                            "interval_list_name": interval_key[
                                "interval_list_name"
                            ],
                            "part_name": part,
                            "part_object_id": part_obj.object_id,
                        }
                    )

        # insert after the file handle is closed; skip_duplicates makes
        # re-ingestion of the same file a no-op
        IntervalList().insert(interval_keys, skip_duplicates=True)
        self.insert(master_keys, skip_duplicates=True)
        self.BodyPart().insert(part_keys, skip_duplicates=True)

    def fetch_pose_dataframe(self, key=None):
        """Fetch pose data as a pandas DataFrame.

        Parameters
        ----------
        key : dict, optional
            Restriction identifying a single entry. Defaults to no
            restriction (valid only if the restricted table has one entry).

        Returns
        -------
        pd.DataFrame
            Indexed by time, with hierarchical columns
            (body_part, [video_frame_ind, x, y, likelihood]).
        """
        # NOTE: key=dict() was a shared mutable default; use None sentinel
        key = (self & (key or {})).fetch1("KEY")
        pose_estimations = (
            (self & key).fetch_nwb()[0]["pose"].pose_estimation_series
        )

        index = None
        pose_df = {}
        for body_part, series in pose_estimations.items():
            if index is None:
                # all body parts are indexed on the first part's timestamps
                index = pd.Index(series.timestamps[:], name="time")

            data = series.data[:]  # read the hdf5 dataset once, not per-column
            pose_df[body_part] = pd.DataFrame(
                {
                    "video_frame_ind": np.nan,
                    "x": data[:, 0],
                    "y": data[:, 1],
                    "likelihood": series.confidence[:],
                },
                index=index,
            )

        return pd.concat(pose_df, axis=1)

    def fetch_skeleton(self, key=None):
        """Fetch the skeleton as node names and name-pair edges.

        Parameters
        ----------
        key : dict, optional
            Restriction identifying a single entry.

        Returns
        -------
        dict
            {"nodes": node names, "edges": list of [node_a, node_b] pairs}
        """
        nwb = (self & (key or {})).fetch_nwb()[0]
        nodes = nwb["skeleton"].nodes[:]
        int_edges = nwb["skeleton"].edges[:]
        # translate integer edge indices into node-name pairs
        named_edges = [[nodes[i], nodes[j]] for i, j in int_edges]
        return {"nodes": nodes, "edges": named_edges}
Loading