Skip to content

Commit

Permalink
first draft of loading function
Browse files Browse the repository at this point in the history
  • Loading branch information
vigji committed Dec 6, 2024
1 parent 1387cc1 commit 86b2a77
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,67 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
"ds_type": "poses",
},
)


def from_anipose_df(anipose_triangulation_df, individual_name="individual_0"):
"""Convert triangulation dataframe to xarray dataset.
Reshape dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z, keypoint1_confidence_score,
keypoint2_x, keypoint2_y, keypoint2_z, keypoint2_confidence_score, ...
to array of positions with dimensions time, individuals, keypoints, space,
and array of confidence scores with dimensions time, individuals, keypoints
Parameters
----------
anipose_triangulation_df : pd.DataFrame
Anipose triangulation dataframe
individual_name : str, optional
Name of the individual, by default "individual_0"
Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata.
"""
keypoint_names = sorted(list(set([col.rsplit('_', 1)[0] for col in anipose_triangulation_df.columns
if any(col.endswith(f'_{s}') for s in ['x','y','z'])])))

n_frames = len(anipose_triangulation_df)
n_keypoints = len(keypoint_names)

# Initialize arrays and fill
position_array = np.zeros((n_frames, 1, n_keypoints, 3)) # 1 for single individual
confidence_array = np.zeros((n_frames, 1, n_keypoints))
for i, kp in enumerate(keypoint_names):
for j, coord in enumerate(['x', 'y', 'z']):
position_array[:, 0, i, j] = anipose_triangulation_df[f'{kp}_{coord}']
confidence_array[:, 0, i] = anipose_triangulation_df[f'{kp}_score']

individual_names = [individual_name]

return from_numpy(position_array=position_array,
confidence_array=confidence_array,
individual_names=individual_names,
keypoint_names=keypoint_names,
source_software="anipose_triangulation")


def from_anipose_csv(anipose_csv_path, individual_name="individual_0"):
"""Convert anipose csv to xarray dataset.
Parameters
----------
anipose_csv_path : pathlib.Path
Path to the Anipose triangulation csv file
individual_name : str, optional
Name of the individual, by default "individual_0"
Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata.
"""
anipose_triangulation_df = pd.read_csv(anipose_csv_path)
# TODO add a validator for the anipose csv file at this level?
return from_anipose_df(anipose_triangulation_df, individual_name)

0 comments on commit 86b2a77

Please sign in to comment.