Skip to content

Commit

Permalink
feat: Add option to create projection alongside z-axis (#919)
Browse files Browse the repository at this point in the history
In Image adjustment, add Image Projection (min, max, mean, sum) with
ability to keep ROI and Mask.
  • Loading branch information
Czaki authored Apr 1, 2023
1 parent 27f662e commit 7fd9822
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 47 deletions.
6 changes: 4 additions & 2 deletions package/PartSeg/_roi_analysis/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,10 @@ def closeEvent(self, event):
super().closeEvent(event)

@staticmethod
def get_project_info(file_path, image):
return ProjectTuple(file_path=file_path, image=image)
def get_project_info(file_path, image, roi_info=None):
if roi_info is None:
roi_info = ROIInfo(None)
return ProjectTuple(file_path=file_path, image=image, roi_info=roi_info)

def set_data(self, data):
self.main_menu.set_data(data)
Expand Down
11 changes: 9 additions & 2 deletions package/PartSeg/_roi_mask/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,15 @@ def closeEvent(self, event: QCloseEvent):
super().closeEvent(event)

@staticmethod
def get_project_info(file_path, image):
return MaskProjectTuple(file_path=file_path, image=image)
def get_project_info(file_path, image, roi_info=None):
if roi_info is None:
roi_info = ROIInfo(None)
return MaskProjectTuple(
file_path=file_path,
image=image,
roi_info=roi_info,
roi_extraction_parameters={i: None for i in roi_info.bound_info},
)

def set_data(self, data):
self.main_menu.set_data(data)
Expand Down
1 change: 1 addition & 0 deletions package/PartSeg/common_gui/image_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, image: Image, transform_dict: Dict[str, TransformBase] = None
layout.addWidget(self.cancel_btn, 2, 0)
layout.addWidget(self.process_btn, 2, 2)
self.setLayout(layout)
self.setWindowTitle("Image adjustment")

def process(self):
values = self.stacked.currentWidget().get_values()
Expand Down
13 changes: 9 additions & 4 deletions package/PartSeg/common_gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,19 +326,21 @@ def show_about_dialog():
AboutDialog().exec_()

@staticmethod
def get_project_info(file_path, image):
def get_project_info(file_path, image, roi_info=None):
raise NotImplementedError

def image_adjust_exec(self):
dial = ImageAdjustmentDialog(self.settings.image)
if dial.exec_():
algorithm = dial.result_val.algorithm
dial2 = ExecuteFunctionDialog(
algorithm.transform, [], {"image": self.settings.image, "arguments": dial.result_val.values}
algorithm.transform,
[],
{"image": self.settings.image, "arguments": dial.result_val.values, "roi_info": self.settings.roi_info},
)
if dial2.exec_():
result: Image = dial2.get_result()
self.settings.set_project_info(self.get_project_info(result.file_path, result))
image, roi_info = dial2.get_result()
self.settings.set_project_info(self.get_project_info(image.file_path, image, roi_info))

def closeEvent(self, event: QCloseEvent):
for el in self.viewer_list:
Expand Down Expand Up @@ -367,6 +369,9 @@ def _screenshot():
return _screenshot

def image_read(self):
if self.settings.image_path is None:
self.setWindowTitle(f"{self.title_base}")
return
folder_name, file_name = os.path.split(self.settings.image_path)
self.setWindowTitle(f"{self.title_base}: {os.path.join(os.path.basename(folder_name), file_name)}")
self.statusBar().showMessage(self.settings.image_path)
Expand Down
5 changes: 3 additions & 2 deletions package/PartSegCore/image_transforming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from PartSegCore.algorithm_describe_base import Register
from PartSegCore.image_transforming.image_projection import ImageProjection
from PartSegCore.image_transforming.interpolate_image import InterpolateImage
from PartSegCore.image_transforming.swap_time_stack import SwapTimeStack
from PartSegCore.image_transforming.transform_base import TransformBase

image_transform_dict = Register(InterpolateImage, SwapTimeStack)
image_transform_dict = Register(InterpolateImage, SwapTimeStack, ImageProjection)

__all__ = ("image_transform_dict", "TransformBase")
__all__ = ("image_transform_dict", "InterpolateImage", "TransformBase", "ImageProjection")
75 changes: 75 additions & 0 deletions package/PartSegCore/image_transforming/image_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from enum import Enum
from typing import Callable, List, Optional, Tuple

import numpy as np
from pydantic import Field

from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegCore.utils import BaseModel
from PartSegImage import Image


class ProjectionType(Enum):
MAX = "max"
MIN = "min"
MEAN = "mean"
SUM = "sum"


class ImageProjectionParams(BaseModel):
projection_type: ProjectionType = Field(ProjectionType.MAX, suffix="Mask and ROI projection will allways use max")
keep_mask = False
keep_roi = False


def _calc_target_shape(image: Image):
new_shape = list(image.shape)
new_shape[image.array_axis_order.index("Z")] = 1
return tuple(new_shape)


class ImageProjection(TransformBase):
__argument_class__ = ImageProjectionParams

@classmethod
def transform(
cls,
image: Image,
roi_info: ROIInfo,
arguments: ImageProjectionParams, # type: ignore[override]
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
project_operator = getattr(np, arguments.projection_type.value)
axis = image.array_axis_order.index("Z")
target_shape = _calc_target_shape(image)
spacing = list(image.spacing)
spacing.pop(axis - 1 if image.time_pos < image.stack_pos else axis)
new_channels = [
project_operator(image.get_channel(i), axis=axis).reshape(target_shape) for i in range(image.channels)
]
new_mask = None
if arguments.keep_mask and image.mask is not None:
new_mask = np.max(image.mask, axis=axis).reshape(target_shape)

roi = None
if arguments.keep_roi and roi_info.roi is not None:
roi = ROIInfo(np.max(image.fit_array_to_image(roi_info.roi), axis=axis).reshape(target_shape))
return (
image.__class__(
data=new_channels, image_spacing=tuple(spacing), channel_names=image.channel_names, mask=new_mask
),
roi,
)

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
return cls.__argument_class__

@classmethod
def calculate_initial(cls, image: Image):
return cls.get_default_values()

@classmethod
def get_name(cls) -> str:
return "Image Projection"
20 changes: 13 additions & 7 deletions package/PartSegCore/image_transforming/interpolate_image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple, Union

from scipy.ndimage import zoom

from PartSegCore.algorithm_describe_base import AlgorithmProperty
from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


Expand All @@ -13,9 +14,10 @@ def get_fields(cls):
return ["It can be very slow.", AlgorithmProperty("scale", "Scale", 1.0)]

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
return ["it can be very slow"] + [
AlgorithmProperty(f"scale_{i.lower()}", f"Scale {i}", 1.0) for i in reversed(component_list)
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
return [
"it can be very slow",
*[AlgorithmProperty(f"scale_{i.lower()}", f"Scale {i}", 1.0) for i in reversed(component_list)],
]

@classmethod
Expand All @@ -24,8 +26,12 @@ def get_name(cls):

@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: Optional[Callable[[str, int], None]] = None
) -> Image:
cls,
image: Image,
roi_info: Optional[ROIInfo],
arguments: dict,
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
keys = [x for x in arguments if x.startswith("scale")]
keys_order = Image.axis_order.lower()
scale_factor = [1.0] * len(keys_order)
Expand All @@ -46,7 +52,7 @@ def transform(
mask = zoom(image.mask, scale_factor[:-1], mode="mirror")
else:
mask = None
return image.substitute(data=array, image_spacing=spacing, mask=mask)
return image.substitute(data=array, image_spacing=spacing, mask=mask), None

@classmethod
def calculate_initial(cls, image: Image):
Expand Down
11 changes: 8 additions & 3 deletions package/PartSegCore/image_transforming/swap_time_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

from PartSegCore.algorithm_describe_base import AlgorithmProperty
from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


class SwapTimeStack(TransformBase):
@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: typing.Optional[typing.Callable[[str, int], None]] = None
) -> Image:
return image.swap_time_and_stack()
cls,
image: Image,
roi_info: ROIInfo,
arguments: dict,
callback_function: typing.Optional[typing.Callable[[str, int], None]] = None,
) -> typing.Tuple[Image, typing.Optional[ROIInfo]]:
return image.swap_time_and_stack(), None

@classmethod
def get_fields_per_dimension(cls, component_list: typing.List[str]):
Expand Down
15 changes: 10 additions & 5 deletions package/PartSegCore/image_transforming/transform_base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from abc import ABC
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple, Union

from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase
from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase, AlgorithmProperty
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


class TransformBase(AlgorithmDescribeBase, ABC):
@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: Optional[Callable[[str, int], None]] = None
) -> Image:
cls,
image: Image,
roi_info: ROIInfo,
arguments: dict,
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
raise NotImplementedError

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
raise NotImplementedError

@classmethod
Expand Down
27 changes: 20 additions & 7 deletions package/tests/test_PartSeg/roi_analysis/test_main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import pytest
from qtpy.QtCore import Qt

from PartSeg._roi_analysis import main_window as analysis_main_window
from PartSeg._roi_analysis.main_window import ChannelProperty, MainWindow, Options
from PartSegCore.analysis import ProjectTuple
from PartSegCore.roi_info import ROIInfo
from PartSegCore.segmentation import ROIExtractionResult


class TestAnalysisMainWindow:
# @pytest.mark.skipif((platform.system() == "Linux") and CI_BUILD, reason="debug test fail")
@pytest.mark.pyside_skip()
def test_opening(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
main_window.main_menu.batch_processing_btn.click()
main_window.main_menu.advanced_btn.click()
Expand All @@ -22,29 +24,40 @@ def test_opening(self, qtbot, tmpdir):

@pytest.mark.pyside_skip()
def test_change_theme(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
assert main_window.raw_image.viewer.theme == "light"
main_window.settings.theme_name = "dark"
assert main_window.raw_image.viewer.theme == "dark"

@pytest.mark.pyside_skip()
def test_scale_bar(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
main_window._scale_bar_warning = False
assert not main_window.result_image.viewer.scale_bar.visible
main_window._toggle_scale_bar()
assert main_window.result_image.viewer.scale_bar.visible

def test_get_project_info(self, image, tmp_path):
res = MainWindow.get_project_info(str(tmp_path / "test.tiff"), image)
assert isinstance(res.roi_info, ROIInfo)
assert isinstance(res, ProjectTuple)

roi = np.zeros(image.shape, dtype=np.uint8)
roi[:, 2:-2] = 1
res = MainWindow.get_project_info(str(tmp_path / "test.tiff"), image, ROIInfo(roi))
assert isinstance(res.roi_info, ROIInfo)
assert set(res.roi_info.bound_info) == {1}


@pytest.fixture()
def analysis_options(qtbot, part_settings):
ch_property = analysis_main_window.ChannelProperty(part_settings, "test")
ch_property = ChannelProperty(part_settings, "test")
qtbot.addWidget(ch_property)
left_image = MagicMock()
synchronize = MagicMock()
options = analysis_main_window.Options(part_settings, ch_property, left_image, synchronize)
options = Options(part_settings, ch_property, left_image, synchronize)
qtbot.addWidget(options)
qtbot.addWidget(options.compare_btn)
return options
Expand Down Expand Up @@ -143,7 +156,7 @@ def test_execution_done(self, analysis_options, part_settings, monkeypatch):
info_mock = Mock()

monkeypatch.setattr(analysis_options, "sender", mock)
monkeypatch.setattr(analysis_main_window.QMessageBox, "information", info_mock)
monkeypatch.setattr("qtpy.QtWidgets.QMessageBox.information", info_mock)

res = ROIExtractionResult(
roi=np.zeros(part_settings.image.shape, dtype="uint8"),
Expand Down
Loading

0 comments on commit 7fd9822

Please sign in to comment.