From 753813e8cf30c81fef919510c14c9df44bb0b5f2 Mon Sep 17 00:00:00 2001 From: thorstenwagner Date: Mon, 4 Dec 2023 12:11:55 +0100 Subject: [PATCH] run spinner when saving targets --- src/napari_tomotwin/common.py | 15 ++++++++++++ src/napari_tomotwin/load_umap.py | 14 +++-------- src/napari_tomotwin/make_targets.py | 37 ++++++++++++++++++++++------- 3 files changed, 46 insertions(+), 20 deletions(-) create mode 100644 src/napari_tomotwin/common.py diff --git a/src/napari_tomotwin/common.py b/src/napari_tomotwin/common.py new file mode 100644 index 0000000..b41825e --- /dev/null +++ b/src/napari_tomotwin/common.py @@ -0,0 +1,15 @@ +import napari +from pyqtspinner import WaitingSpinner +from qtpy.QtCore import Qt +from qtpy.QtGui import QColor # pylint: disable=E0611 + + +def make_spinner(): + return WaitingSpinner(napari.current_viewer().window._qt_window, + True, + True, Qt.ApplicationModal, + color=QColor(255, 255, 255), + fade=60, + line_width=5, + line_length=15, + ) \ No newline at end of file diff --git a/src/napari_tomotwin/load_umap.py b/src/napari_tomotwin/load_umap.py index 60c57dc..32f406f 100644 --- a/src/napari_tomotwin/load_umap.py +++ b/src/napari_tomotwin/load_umap.py @@ -10,6 +10,7 @@ from qtpy.QtGui import QGuiApplication, QColor # pylint: disable=E0611 from typing import List from pyqtspinner import WaitingSpinner +from napari.qt.threading import thread_worker plotter_widget: PlotterWidget = None circles: List[Circle] = [] @@ -46,7 +47,6 @@ def _draw_circle(data_coordinates, label_layer, umap): plotter_widget.graphics_widget.axes.add_patch(circle) plotter_widget.graphics_widget.draw_idle() -from napari.qt.threading import thread_worker @thread_worker() def run_clusters_plotter(plotter_widget, @@ -115,15 +115,7 @@ def get_event(viewer, event): data_coordinates = label_layer.world_to_data(event.position) _draw_circle(data_coordinates,label_layer,umap) -def make_spinner(): - return WaitingSpinner(napari.current_viewer().window._qt_window, - True, - True, Qt.ApplicationModal, - color=QColor(255, 255, 255), - fade=60, - line_width=5, - line_length=15, - ) + def stop_spinner(): spinner.stop() @@ -133,7 +125,7 @@ def load_umap(label_layer: "napari.layers.Labels", global plotter_widget global spinner - + from napari_tomotwin.common import make_spinner spinner = make_spinner() spinner.start() # starts spinning diff --git a/src/napari_tomotwin/make_targets.py b/src/napari_tomotwin/make_targets.py index 42b7a5d..711ba4a 100644 --- a/src/napari_tomotwin/make_targets.py +++ b/src/napari_tomotwin/make_targets.py @@ -1,6 +1,5 @@ import os import pathlib -import sys from typing import List, Tuple, Literal, Callable import numpy as np @@ -8,6 +7,9 @@ import pandas as pd from magicgui import magic_factory from scipy.spatial.distance import cdist +from napari.qt.threading import thread_worker + +global spinner def get_non_numeric_column_titles(df: pd.DataFrame): @@ -33,7 +35,6 @@ def _get_avg_embedding(embeddings: pd.DataFrame) -> Tuple[pd.DataFrame, npt.Arra target = only_emb.mean(axis=0) return target, np.array([]) - def _make_targets(embeddings: pd.DataFrame, clusters: pd.DataFrame, avg_func: Callable[[pd.DataFrame], npt.ArrayLike]) -> Tuple[pd.DataFrame, List[pd.DataFrame], dict]: targets = [] sub_embeddings = [] @@ -62,6 +63,7 @@ def _make_targets(embeddings: pd.DataFrame, clusters: pd.DataFrame, avg_func: Ca return targets, sub_embeddings, target_locations + def _run(clusters, embeddings: pd.DataFrame, output_folder: pathlib.Path, @@ -73,9 +75,12 @@ def _run(clusters, if average_method_name == "Average": avg_method = _get_avg_embedding + + print("Make targets") #embeddings = embeddings.reset_index() embeddings = embeddings.drop(columns=["level_0","index"], errors="ignore") + targets, sub_embeddings, target_locations = _make_targets(embeddings, clusters, avg_func=avg_method) print("Write targets") @@ -98,6 +103,20 @@ def _run(clusters, print("Done") + +@thread_worker +def _run_worker(embeddings_filepath, label_layer, output_folder: str, average_method_name: str): + print("Read clusters") + clusters = label_layer.features['MANUAL_CLUSTER_ID'] + + print("Read embeddings") + embeddings = pd.read_pickle(embeddings_filepath) + _run(clusters, embeddings, output_folder, average_method_name) + +def stop_spinner(): + global spinner + spinner.stop() + @magic_factory( call_button="Save", label_layer={'label': 'TomoTwin Label Mask:'}, @@ -115,11 +134,11 @@ def make_targets( output_folder: pathlib.Path, average_method_name: Literal["Average", "Medoid"] = "Medoid", ): + global spinner + from napari_tomotwin.common import make_spinner + spinner = make_spinner() + spinner.start() # starts spinning + worker = _run_worker(embeddings_filepath, label_layer, output_folder, average_method_name) # create "worker" object + worker.returned.connect(stop_spinner) + worker.start() - print("Read clusters") - clusters = label_layer.features['MANUAL_CLUSTER_ID'] - - print("Read embeddings") - embeddings = pd.read_pickle(embeddings_filepath) - - _run(clusters, embeddings, output_folder, average_method_name)