Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qt widget for loading pose datasets as napari Points layers #253

Merged
merged 50 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
bae33b9
initialise napari plugin development
niksirbi Jun 12, 2024
cf5d66e
Create skeleton for napari plugin with collapsible widgets (#218)
niksirbi Jul 17, 2024
d3d3735
initialise napari plugin development
niksirbi Jun 12, 2024
258b530
initialise napari plugin development
niksirbi Jun 12, 2024
d332829
initialise napari plugin development
niksirbi Jun 12, 2024
f4b3e39
Added loader widget for poses
niksirbi Jul 30, 2024
88e71a0
update widget tests
niksirbi Jul 30, 2024
e770fc9
simplify dependency on brainglobe-utils
niksirbi Sep 2, 2024
f880d7a
consistent monospace formatting for movement in public docstrings
niksirbi Sep 2, 2024
4c32fac
get rid of code that's only relevant for displaying Tracks
niksirbi Sep 2, 2024
1375164
enable visibility of napari layer tooltips
niksirbi Sep 2, 2024
1abec3e
renamed widget to PosesLoader
niksirbi Sep 2, 2024
3247769
make cmap optional in set_color_by method
niksirbi Sep 3, 2024
b7e3cbd
wrote unit tests for napari convert module
niksirbi Sep 3, 2024
955341c
wrote unit-tests for the layer styles module
niksirbi Sep 12, 2024
dab446c
linkcheck ignore zenodo redirects
niksirbi Sep 12, 2024
07c0272
move _sample_colormap out of PointsStyle class
niksirbi Sep 13, 2024
5c1d04f
small refactoring in the loader widget
niksirbi Sep 13, 2024
200a166
Expand tests for loader widget
niksirbi Sep 13, 2024
3caca4e
added comments and docstrings to napari plugin tests
niksirbi Sep 16, 2024
51fb610
refactored all napari tests into separate unit test folder
niksirbi Sep 16, 2024
d917395
added napari-video to dependencies
niksirbi Sep 16, 2024
46850f7
replaced deprecated edge_width with border_width
niksirbi Sep 16, 2024
7dea728
got rid of widget pytest fixtures
niksirbi Sep 16, 2024
e6919a9
remove duplicate word from docstring
niksirbi Sep 16, 2024
93c116b
remove napari-video dependency
niksirbi Oct 4, 2024
0fe5b08
include napari extras in docs requirements
niksirbi Oct 10, 2024
5947a2e
add test for _on_browse_clicked method
niksirbi Nov 4, 2024
8962512
getOpenFileName returns tuple, not str
niksirbi Nov 4, 2024
3f1771d
simplify poses_to_napari_tracks
niksirbi Nov 19, 2024
8a9017f
[pre-commit.ci] pre-commit autoupdate (#338)
pre-commit-ci[bot] Nov 4, 2024
0c2917b
Implement `compute_speed` and `compute_path_length` (#280)
niksirbi Nov 5, 2024
61c643a
initialise napari plugin development
niksirbi Jun 12, 2024
a86fa93
initialise napari plugin development
niksirbi Jun 12, 2024
279e4f5
initialise napari plugin development
niksirbi Jun 12, 2024
46cf913
initialise napari plugin development
niksirbi Jun 12, 2024
62c87c7
initialise napari plugin development
niksirbi Jun 12, 2024
1380747
avoid redefining duplicate attributes in child dataclass
niksirbi Nov 19, 2024
2ce361c
modify test case to match poses_to_napari_tracks simplification
niksirbi Nov 19, 2024
f92245e
expected_log_messages should be a subset of captured messages
niksirbi Nov 19, 2024
21c2b32
fix typo
niksirbi Nov 19, 2024
9cdcfc3
use names for Qwidgets
niksirbi Nov 19, 2024
de46d4d
reorganised test_valid_poses_to_napari_tracks
niksirbi Nov 19, 2024
2a6cc97
parametrised layer style tests
niksirbi Nov 19, 2024
8a400f8
Merge branch 'napari-dev' into napari-loader-widget
niksirbi Nov 19, 2024
785c35c
delet integration test which was reintroduced after conflict resolution
niksirbi Nov 19, 2024
6483915
added test about file filters
niksirbi Nov 19, 2024
5cb177d
deleted obsolete loader widget file (had snuck back in due to conflic…
niksirbi Nov 19, 2024
9bbbca9
combine tests for button callouts
niksirbi Nov 20, 2024
faa1549
Simplify test_layer_style_as_kwargs
niksirbi Nov 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-e .
-e .[napari]
linkify-it-py
myst-parser
nbsphinx
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error
]


myst_url_schemes = {
"http": None,
"https": None,
Expand Down
32 changes: 0 additions & 32 deletions movement/napari/_loader_widget.py

This file was deleted.

161 changes: 161 additions & 0 deletions movement/napari/_loader_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Widgets for loading movement datasets from file."""

import logging
from pathlib import Path

from napari.settings import get_settings
from napari.utils.notifications import show_warning
from napari.viewer import Viewer
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QFormLayout,
QHBoxLayout,
QLineEdit,
QPushButton,
QSpinBox,
QWidget,
)

from movement.io import load_poses
from movement.napari.convert import poses_to_napari_tracks
from movement.napari.layer_styles import PointsStyle

logger = logging.getLogger(__name__)

# Allowed poses file suffixes for each supported source software
SUPPORTED_POSES_FILES = {
"DeepLabCut": ["*.h5", "*.csv"],
"LightningPose": ["*.csv"],
"SLEAP": ["*.h5", "*.slp"],
}


class PosesLoader(QWidget):
"""Widget for loading movement poses datasets from file."""

def __init__(self, napari_viewer: Viewer, parent=None):
"""Initialize the loader widget."""
super().__init__(parent=parent)
self.viewer = napari_viewer
self.setLayout(QFormLayout())
# Create widgets
self._create_source_software_widget()
self._create_fps_widget()
self._create_file_path_widget()
self._create_load_button()
# Enable layer tooltips from napari settings
self._enable_layer_tooltips()

def _create_source_software_widget(self):
"""Create a combo box for selecting the source software."""
self.source_software_combo = QComboBox()
self.source_software_combo.setObjectName("source_software_combo")
self.source_software_combo.addItems(SUPPORTED_POSES_FILES.keys())
self.layout().addRow("source software:", self.source_software_combo)

def _create_fps_widget(self):
"""Create a spinbox for selecting the frames per second (fps)."""
self.fps_spinbox = QSpinBox()
self.fps_spinbox.setObjectName("fps_spinbox")
self.fps_spinbox.setMinimum(1)
self.fps_spinbox.setMaximum(1000)
self.fps_spinbox.setValue(30)
self.layout().addRow("fps:", self.fps_spinbox)

def _create_file_path_widget(self):
"""Create a line edit and browse button for selecting the file path.

This allows the user to either browse the file system,
or type the path directly into the line edit.
"""
# File path line edit and browse button
self.file_path_edit = QLineEdit()
self.file_path_edit.setObjectName("file_path_edit")
self.browse_button = QPushButton("Browse")
self.browse_button.setObjectName("browse_button")
self.browse_button.clicked.connect(self._on_browse_clicked)
# Layout for line edit and button
self.file_path_layout = QHBoxLayout()
self.file_path_layout.addWidget(self.file_path_edit)
self.file_path_layout.addWidget(self.browse_button)
self.layout().addRow("file path:", self.file_path_layout)

def _create_load_button(self):
"""Create a button to load the file and add layers to the viewer."""
self.load_button = QPushButton("Load")
self.load_button.setObjectName("load_button")
self.load_button.clicked.connect(lambda: self._on_load_clicked())
self.layout().addRow(self.load_button)

def _on_browse_clicked(self):
"""Open a file dialog to select a file."""
file_suffixes = SUPPORTED_POSES_FILES[
self.source_software_combo.currentText()
]

file_path, _ = QFileDialog.getOpenFileName(
self,
caption="Open file containing predicted poses",
filter=f"Poses files ({' '.join(file_suffixes)})",
)

# A blank string is returned if the user cancels the dialog
if not file_path:
return

# Add the file path to the line edit (text field)
self.file_path_edit.setText(file_path)

def _on_load_clicked(self):
"""Load the file and add as a Points layer to the viewer."""
fps = self.fps_spinbox.value()
source_software = self.source_software_combo.currentText()
file_path = self.file_path_edit.text()
if file_path == "":
show_warning("No file path specified.")
return
ds = load_poses.from_file(file_path, source_software, fps)

self.data, self.props = poses_to_napari_tracks(ds)
logger.info("Converted poses dataset to a napari Tracks array.")
logger.debug(f"Tracks array shape: {self.data.shape}")

self.file_name = Path(file_path).name
self._add_points_layer()

self._set_playback_fps(fps)
logger.debug(f"Set napari playback speed to {fps} fps.")

def _add_points_layer(self):
"""Add the predicted poses to the viewer as a Points layer."""
# Style properties for the napari Points layer
points_style = PointsStyle(
name=f"poses: {self.file_name}",
properties=self.props,
)
# Color the points by individual if there are multiple individuals
# Otherwise, color by keypoint
n_individuals = len(self.props["individual"].unique())
points_style.set_color_by(
prop="individual" if n_individuals > 1 else "keypoint"
)
# Add the points layer to the viewer
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())
logger.info("Added poses dataset as a napari Points layer.")

@staticmethod
def _set_playback_fps(fps: int):
"""Set the playback speed for the napari viewer."""
settings = get_settings()
settings.application.playback_fps = fps

@staticmethod
def _enable_layer_tooltips():
"""Toggle on tooltip visibility for napari layers.

This nicely displays the layer properties as a tooltip
when hovering over the layer in the napari viewer.
"""
settings = get_settings()
settings.appearance.layer_tooltip_visibility = True
6 changes: 3 additions & 3 deletions movement/napari/_meta_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

from movement.napari._loader_widget import Loader
from movement.napari._loader_widgets import PosesLoader


class MovementMetaWidget(CollapsibleWidgetContainer):
Expand All @@ -18,9 +18,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__()

self.add_widget(
Loader(napari_viewer, parent=self),
PosesLoader(napari_viewer, parent=self),
collapsible=True,
widget_title="Load data",
widget_title="Load poses",
)

self.loader = self.collapsible_widgets[0]
Expand Down
73 changes: 73 additions & 0 deletions movement/napari/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Conversion functions from ``movement`` datasets to napari layers."""

import logging

import numpy as np
import pandas as pd
import xarray as xr

# get logger
logger = logging.getLogger(__name__)


def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame:
"""Construct a properties DataFrame from a ``movement`` dataset."""
return pd.DataFrame(
{
"individual": ds.coords["individuals"].values,
"keypoint": ds.coords["keypoints"].values,
"time": ds.coords["time"].values,
"confidence": ds["confidence"].values.flatten(),
}
)


def poses_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]:
"""Convert poses dataset to napari Tracks array and properties.

Parameters
----------
ds : xr.Dataset
``movement`` dataset containing pose tracks, confidence scores,
and associated metadata.

Returns
-------
data : np.ndarray
napari Tracks array with shape (N, 4),
where N is n_keypoints * n_individuals * n_frames
and the 4 columns are (track_id, frame_idx, y, x).
properties : pd.DataFrame
DataFrame with properties (individual, keypoint, time, confidence).

Notes
-----
A corresponding napari Points array can be derived from the Tracks array
by taking its last 3 columns: (frame_idx, y, x). See the documentation
on the napari Tracks [1]_ and Points [2]_ layers.

References
----------
.. [1] https://napari.org/stable/howtos/layers/tracks.html
.. [2] https://napari.org/stable/howtos/layers/points.html

"""
n_frames = ds.sizes["time"]
n_individuals = ds.sizes["individuals"]
n_keypoints = ds.sizes["keypoints"]
n_tracks = n_individuals * n_keypoints
# Construct the napari Tracks array
# Reorder axes to (individuals, keypoints, frames, xy)
yx_cols = np.transpose(ds.position.values, (1, 2, 0, 3)).reshape(-1, 2)[
:, [1, 0] # swap x and y columns
]
# Each keypoint of each individual is a separate track
track_id_col = np.repeat(np.arange(n_tracks), n_frames).reshape(-1, 1)
time_col = np.tile(np.arange(n_frames), (n_tracks)).reshape(-1, 1)
data = np.hstack((track_id_col, time_col, yx_cols))
# Construct the properties DataFrame
# Stack 3 dimensions into a new single dimension named "tracks"
ds_ = ds.stack(tracks=("individuals", "keypoints", "time"))
properties = _construct_properties_dataframe(ds_)

return data, properties
64 changes: 64 additions & 0 deletions movement/napari/layer_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Dataclasses containing layer styles for napari."""

from dataclasses import dataclass, field

import numpy as np
import pandas as pd
from napari.utils.colormaps import ensure_colormap

DEFAULT_COLORMAP = "turbo"


@dataclass
class LayerStyle:
"""Base class for napari layer styles."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"

def as_kwargs(self) -> dict:
"""Return the style properties as a dictionary of kwargs."""
return self.__dict__


@dataclass
class PointsStyle(LayerStyle):
"""Style properties for a napari Points layer."""

symbol: str = "disc"
size: int = 10
border_width: int = 0
face_color: str | None = None
face_color_cycle: list[tuple] | None = None
face_colormap: str = DEFAULT_COLORMAP
text: dict = field(default_factory=lambda: {"visible": False})

def set_color_by(self, prop: str, cmap: str | None = None) -> None:
"""Set the face_color to a column in the properties DataFrame.

Parameters
----------
prop : str
The column name in the properties DataFrame to color by.
cmap : str, optional
The name of the colormap to use, otherwise use the face_colormap.

"""
if cmap is None:
cmap = self.face_colormap
self.face_color = prop
self.text["string"] = prop
n_colors = len(self.properties[prop].unique())
self.face_color_cycle = _sample_colormap(n_colors, cmap)


def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap.

This includes the endpoints of the colormap.
"""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@ entry-points."napari.manifest".movement = "movement.napari:napari.yaml"

[project.optional-dependencies]
napari = [
"napari[all]>=0.4.19",
# the rest will be replaced by brainglobe-utils[qt]>=0.6 after release
"brainglobe-atlasapi>=2.0.7",
"brainglobe-utils>=0.5",
"qtpy",
"superqt",
"napari[all]>=0.5.0",
"brainglobe-utils[qt]>=0.6" # needed for collapsible widgets
]
dev = [
"pytest",
Expand Down
Loading
Loading