Skip to content

Commit

Permalink
Refactor: Move FlexibleWorker Class to Dedicated Workers Module
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Nov 11, 2024
1 parent 11bff4e commit 66c41ed
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 75 deletions.
147 changes: 76 additions & 71 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from labelme.utils import newAction
from labelme.app import MainWindow
from annolid.gui.shape import Shape
from annolid.gui.workers import FlexibleWorker
import subprocess
import requests
from PIL import ImageQt
Expand Down Expand Up @@ -77,39 +78,6 @@
LABEL_COLORMAP = imgviz.label_colormap(value=200)


class FlexibleWorker(QtCore.QObject):
start = QtCore.Signal()
finished = QtCore.Signal(object)
return_value = QtCore.Signal(object)
stop_signal = QtCore.Signal()
progress_changed = QtCore.Signal(int)

def __init__(self, function, *args, **kwargs):
super(FlexibleWorker, self).__init__()

self.function = function
self.args = args
self.kwargs = kwargs
self.stopped = False

self.stop_signal.connect(self.stop)

def run(self):
self.stopped = False
result = self.function(*self.args, **self.kwargs)
self.return_value.emit(result)
self.finished.emit(result)

def stop(self):
self.stopped = True

def is_stopped(self):
return self.stopped

def progress_callback(self, progress):
self.progress_changed.emit(progress)


class LoadFrameThread(QtCore.QObject):
"""
Thread for loading video frames.
Expand Down Expand Up @@ -1180,14 +1148,14 @@ def predict_from_next_frame(self,
or self.automatic_pause_enabled)
if "sam2_hiera" in model_name:
self.pred_worker = FlexibleWorker(
function=process_video,
task_function=process_video,
video_path=self.video_file,
frame_idx=self.frame_number,
model_config='sam2_hiera_l.yaml' if 'hiera_l' in model_name else "sam2_hiera_s.yaml",
)
else:
self.pred_worker = FlexibleWorker(
function=self.video_processor.process_video_frames,
task_function=self.video_processor.process_video_frames,
start_frame=self.frame_number+1,
end_frame=end_frame,
step=self.step_size,
Expand All @@ -1206,13 +1174,13 @@ def predict_from_next_frame(self,
"background-color: red; color: white;")
self.stop_prediction_flag = True
self.pred_worker.moveToThread(self.seg_pred_thread)
self.pred_worker.start.connect(self.pred_worker.run)
self.pred_worker.return_value.connect(
self.pred_worker.start_signal.connect(self.pred_worker.run)
self.pred_worker.result_signal.connect(
self.lost_tracking_instance)
self.pred_worker.finished.connect(self.predict_is_ready)
self.pred_worker.finished_signal.connect(self.predict_is_ready)
self.seg_pred_thread.finished.connect(
self.seg_pred_thread.quit)
self.pred_worker.start.emit()
self.pred_worker.start_signal.emit()

def lost_tracking_instance(self, message):
if message is None:
Expand Down Expand Up @@ -1604,41 +1572,78 @@ def frames(self):
self.importDirImages(out_frames_dir)

def convert_json_to_tracked_csv(self):
if self.video_file is not None:
video_file = self.video_file
out_folder = Path(video_file).with_suffix('')
if out_folder is None or not out_folder.exists():
QtWidgets.QMessageBox.about(self,
"No predictions",
"Help Annolid achieve precise predictions by labeling a frame.\
Your input is valuable!")
"""
Convert JSON annotations to a tracked CSV file and handle the progress using a separate thread.
"""
if not self.video_file:
QtWidgets.QMessageBox.warning(
self, "Missing Video File", "No video file selected.")
return

video_file = self.video_file
out_folder = Path(video_file).with_suffix('')

if not out_folder or not out_folder.exists():
QtWidgets.QMessageBox.warning(
self,
"No Predictions Found",
"Help Annolid achieve precise predictions by labeling a frame. Your input is valuable!"
)
return

def update_progress(progress):
self.progress_bar.setValue(progress)
self._initialize_progress_bar()

try:
self.worker = FlexibleWorker(
task_function=labelme2csv.convert_json_to_csv,
json_folder=str(out_folder),
progress_callback=self._update_progress_bar
)
self.thread = QtCore.QThread()

# Move the worker to the thread and connect signals
self.worker.moveToThread(self.thread)
self._connect_worker_signals()

# Safely start the thread and worker signal
self.thread.start()
# Emit in a thread-safe way
QtCore.QTimer.singleShot(
0, lambda: self.worker.start_signal.emit())

except Exception as e:
QtWidgets.QMessageBox.critical(
self, "Error", f"An unexpected error occurred: {str(e)}")
finally:
self.statusBar().removeWidget(self.progress_bar)

def _initialize_progress_bar(self):
"""Initialize the progress bar and add it to the status bar."""
self.progress_bar.setValue(0)
self.statusBar().addWidget(self.progress_bar)

self.worker = FlexibleWorker(
labelme2csv.convert_json_to_csv, str(out_folder),
progress_callback=update_progress)
self.thread = QtCore.QThread()
self.worker.moveToThread(self.thread)
self.worker.start.connect(self.worker.run)
self.worker.finished.connect(self.place_preference_analyze_auto)
self.worker.finished.connect(self.thread.quit)
self.worker.finished.connect(self.worker.deleteLater)
self.thread.finished.connect(self.thread.deleteLater)
self.worker.finished.connect(lambda:
QtWidgets.QMessageBox.about(self,
"Tracking results are ready.",
f"Kindly review the file here: {str(out_folder) + '.csv'}."))
self.worker.progress_changed.connect(update_progress)

self.thread.start()
self.worker.start.emit()
self.statusBar().removeWidget(self.progress_bar)
def _update_progress_bar(self, progress):
"""Update the progress bar's value."""
self.progress_bar.setValue(progress)

def _connect_worker_signals(self):
"""Connect worker signals to their respective slots safely."""
self.worker.start_signal.connect(self.worker.run)
self.worker.finished_signal.connect(self.place_preference_analyze_auto)

# Ensure cleanup happens in the right thread
self.worker.finished_signal.connect(self.thread.quit)
self.worker.finished_signal.connect(lambda: self.worker.deleteLater())
self.thread.finished.connect(lambda: self.thread.deleteLater())

self.worker.finished_signal.connect(
lambda: QtWidgets.QMessageBox.information(
self,
"Tracking Complete",
f"Kindly review the file here: {Path(self.video_file).with_suffix('.csv')}"
)
)
self.worker.progress_signal.connect(self._update_progress_bar)

def tracks(self):
"""
Expand Down Expand Up @@ -1821,10 +1826,10 @@ def models(self):
process = start_tensorboard(log_dir=out_runs_dir)
try:
self.seg_train_thread.start()
train_worker = FlexibleWorker(function=segmentor.train)
train_worker = FlexibleWorker(task_function=segmentor.train)
train_worker.moveToThread(self.seg_train_thread)
train_worker.start.connect(train_worker.run)
train_worker.start.emit()
train_worker.start_signal.connect(train_worker.run)
train_worker.start_signal.emit()
except Exception:
segmentor.train()

Expand Down
4 changes: 2 additions & 2 deletions annolid/gui/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def save(
caption=caption,
)
for key, value in otherData.items():
assert key not in data
data[key] = value
if key not in data:
data[key] = value
try:
with open(filename, "w") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
Expand Down
1 change: 1 addition & 0 deletions annolid/gui/widgets/caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class CaptionWidget(QtWidgets.QWidget):
charInserted = Signal(str) # Signal emitted when a character is inserted
charDeleted = Signal(str) # Signal emitted when a character is deleted
readCaptionFinished = Signal() # Define a custom signal
imageNotFound = Signal(str)

def __init__(self, parent=None):
super().__init__(parent)
Expand Down
67 changes: 67 additions & 0 deletions annolid/gui/workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from qtpy import QtCore


class FlexibleWorker(QtCore.QObject):
"""
A flexible worker class that runs a given function in a separate thread.
Provides signals to indicate the start, progress, return value, and completion of the task.
"""

start_signal = QtCore.Signal()
finished_signal = QtCore.Signal(object)
result_signal = QtCore.Signal(object)
stop_signal = QtCore.Signal()
progress_signal = QtCore.Signal(int)

def __init__(self, task_function, *args, **kwargs):
"""
Initialize the FlexibleWorker with the function to run and its arguments.
:param task_function: The function to be executed.
:param args: Positional arguments for the function.
:param kwargs: Keyword arguments for the function.
"""
super().__init__()
self._task_function = task_function
self._args = args
self._kwargs = kwargs
self._is_stopped = False

# Connect the stop signal to the stop method
self.stop_signal.connect(self._stop)

def run(self):
"""
Executes the task function with the provided arguments.
Emits signals for result and completion when done.
"""
self._is_stopped = False
try:
result = self._task_function(*self._args, **self._kwargs)
self.result_signal.emit(result)
self.finished_signal.emit(result)
except Exception as e:
# Optionally handle exceptions and emit an error signal if needed
self.finished_signal.emit(e)

def _stop(self):
"""
Stops the worker by setting the stop flag.
"""
self._is_stopped = True

def is_stopped(self):
"""
Check if the worker has been stopped.
:return: True if the worker is stopped, otherwise False.
"""
return self._is_stopped

def report_progress(self, progress):
"""
Reports the progress of the task.
:param progress: An integer representing the progress percentage.
"""
self.progress_signal.emit(progress)
4 changes: 2 additions & 2 deletions annolid/segmentation/cutie_vos/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _initialize_model(self):
logger.info(f"Tmax: max_mem_frames: {self.max_mem_frames}")
cutie_model = CUTIE(cfg).to(self.device).eval()
model_weights = torch.load(
cfg.weights, map_location=self.device)
cfg.weights, map_location=self.device, weights_only=True)
cutie_model.load_weights(model_weights)
return cutie_model, cfg

Expand Down Expand Up @@ -271,7 +271,7 @@ def commit_masks_into_permanent_memory(self, frame_number, labels_dict):
dict: Updated labels dictionary.
"""
with torch.inference_mode():
with torch.amp.autocast('cuda',enabled=self.cfg.amp and self.device == 'cuda'):
with torch.amp.autocast('cuda', enabled=self.cfg.amp and self.device == 'cuda'):
png_file_paths = glob.glob(
f"{self.video_folder}/{self.video_folder.name}_0*.png")
png_file_paths = [p for p in png_file_paths if 'mask' not in p]
Expand Down

0 comments on commit 66c41ed

Please sign in to comment.