Skip to content

Commit

Permalink
run spinner when saving targets
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenwagner committed Dec 4, 2023
1 parent 991d898 commit 753813e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
15 changes: 15 additions & 0 deletions src/napari_tomotwin/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import napari
from pyqtspinner import WaitingSpinner
from qtpy.QtCore import Qt
from qtpy.QtGui import QColor # pylint: disable=E0611


def make_spinner():
return WaitingSpinner(napari.current_viewer().window._qt_window,
True,
True, Qt.ApplicationModal,
color=QColor(255, 255, 255),
fade=60,
line_width=5,
line_length=15,
)
14 changes: 3 additions & 11 deletions src/napari_tomotwin/load_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from qtpy.QtGui import QGuiApplication, QColor # pylint: disable=E0611
from typing import List
from pyqtspinner import WaitingSpinner
from napari.qt.threading import thread_worker

plotter_widget: PlotterWidget = None
circles: List[Circle] = []
Expand Down Expand Up @@ -46,7 +47,6 @@ def _draw_circle(data_coordinates, label_layer, umap):
plotter_widget.graphics_widget.axes.add_patch(circle)
plotter_widget.graphics_widget.draw_idle()

from napari.qt.threading import thread_worker

@thread_worker()
def run_clusters_plotter(plotter_widget,
Expand Down Expand Up @@ -115,15 +115,7 @@ def get_event(viewer, event):
data_coordinates = label_layer.world_to_data(event.position)
_draw_circle(data_coordinates,label_layer,umap)

def make_spinner():
return WaitingSpinner(napari.current_viewer().window._qt_window,
True,
True, Qt.ApplicationModal,
color=QColor(255, 255, 255),
fade=60,
line_width=5,
line_length=15,
)

def stop_spinner():
spinner.stop()

Expand All @@ -133,7 +125,7 @@ def load_umap(label_layer: "napari.layers.Labels",
global plotter_widget
global spinner


from napari_tomotwin.common import make_spinner
spinner = make_spinner()
spinner.start() # starts spinning

Expand Down
37 changes: 28 additions & 9 deletions src/napari_tomotwin/make_targets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
import pathlib
import sys
from typing import List, Tuple, Literal, Callable

import numpy as np
import numpy.typing as npt
import pandas as pd
from magicgui import magic_factory
from scipy.spatial.distance import cdist
from napari.qt.threading import thread_worker

global spinner

def get_non_numeric_column_titles(df: pd.DataFrame):

Expand All @@ -33,7 +35,6 @@ def _get_avg_embedding(embeddings: pd.DataFrame) -> Tuple[pd.DataFrame, npt.Arra
target = only_emb.mean(axis=0)
return target, np.array([])


def _make_targets(embeddings: pd.DataFrame, clusters: pd.DataFrame, avg_func: Callable[[pd.DataFrame], npt.ArrayLike]) -> Tuple[pd.DataFrame, List[pd.DataFrame], dict]:
targets = []
sub_embeddings = []
Expand Down Expand Up @@ -62,6 +63,7 @@ def _make_targets(embeddings: pd.DataFrame, clusters: pd.DataFrame, avg_func: Ca
return targets, sub_embeddings, target_locations



def _run(clusters,
embeddings: pd.DataFrame,
output_folder: pathlib.Path,
Expand All @@ -73,9 +75,12 @@ def _run(clusters,
if average_method_name == "Average":
avg_method = _get_avg_embedding



print("Make targets")
#embeddings = embeddings.reset_index()
embeddings = embeddings.drop(columns=["level_0","index"], errors="ignore")

targets, sub_embeddings, target_locations = _make_targets(embeddings, clusters, avg_func=avg_method)

print("Write targets")
Expand All @@ -98,6 +103,20 @@ def _run(clusters,

print("Done")


@thread_worker
def _run_worker(embeddings_filepath, label_layer, output_folder: str, average_method_name: str):
print("Read clusters")
clusters = label_layer.features['MANUAL_CLUSTER_ID']

print("Read embeddings")
embeddings = pd.read_pickle(embeddings_filepath)
_run(clusters, embeddings, output_folder, average_method_name)

def stop_spinner():
global spinner
spinner.stop()

@magic_factory(
call_button="Save",
label_layer={'label': 'TomoTwin Label Mask:'},
Expand All @@ -115,11 +134,11 @@ def make_targets(
output_folder: pathlib.Path,
average_method_name: Literal["Average", "Medoid"] = "Medoid",
):
global spinner
from napari_tomotwin.common import make_spinner
spinner = make_spinner()
spinner.start() # starts spinning
worker = _run_worker(embeddings_filepath, label_layer, output_folder, average_method_name) # create "worker" object
worker.returned.connect(stop_spinner)
worker.start()

print("Read clusters")
clusters = label_layer.features['MANUAL_CLUSTER_ID']

print("Read embeddings")
embeddings = pd.read_pickle(embeddings_filepath)

_run(clusters, embeddings, output_folder, average_method_name)

0 comments on commit 753813e

Please sign in to comment.