Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add option to create projection alongside z-axis #919

Merged
merged 11 commits into from
Apr 1, 2023
Merged
6 changes: 4 additions & 2 deletions package/PartSeg/_roi_analysis/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,10 @@ def closeEvent(self, event):
super().closeEvent(event)

@staticmethod
def get_project_info(file_path, image):
return ProjectTuple(file_path=file_path, image=image)
def get_project_info(file_path, image, roi_info=None):
if roi_info is None:
roi_info = ROIInfo(None)
return ProjectTuple(file_path=file_path, image=image, roi_info=roi_info)

def set_data(self, data):
self.main_menu.set_data(data)
Expand Down
11 changes: 9 additions & 2 deletions package/PartSeg/_roi_mask/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,15 @@ def closeEvent(self, event: QCloseEvent):
super().closeEvent(event)

@staticmethod
def get_project_info(file_path, image):
return MaskProjectTuple(file_path=file_path, image=image)
def get_project_info(file_path, image, roi_info=None):
if roi_info is None:
roi_info = ROIInfo(None)
return MaskProjectTuple(
file_path=file_path,
image=image,
roi_info=roi_info,
roi_extraction_parameters={i: None for i in roi_info.bound_info},
)

def set_data(self, data):
self.main_menu.set_data(data)
Expand Down
1 change: 1 addition & 0 deletions package/PartSeg/common_gui/image_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, image: Image, transform_dict: Dict[str, TransformBase] = None
layout.addWidget(self.cancel_btn, 2, 0)
layout.addWidget(self.process_btn, 2, 2)
self.setLayout(layout)
self.setWindowTitle("Image adjustment")

def process(self):
values = self.stacked.currentWidget().get_values()
Expand Down
13 changes: 9 additions & 4 deletions package/PartSeg/common_gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,19 +326,21 @@ def show_about_dialog():
AboutDialog().exec_()

@staticmethod
def get_project_info(file_path, image):
def get_project_info(file_path, image, roi_info=None):
raise NotImplementedError

def image_adjust_exec(self):
dial = ImageAdjustmentDialog(self.settings.image)
if dial.exec_():
algorithm = dial.result_val.algorithm
dial2 = ExecuteFunctionDialog(
algorithm.transform, [], {"image": self.settings.image, "arguments": dial.result_val.values}
algorithm.transform,
[],
{"image": self.settings.image, "arguments": dial.result_val.values, "roi_info": self.settings.roi_info},
)
if dial2.exec_():
result: Image = dial2.get_result()
self.settings.set_project_info(self.get_project_info(result.file_path, result))
image, roi_info = dial2.get_result()
self.settings.set_project_info(self.get_project_info(image.file_path, image, roi_info))

def closeEvent(self, event: QCloseEvent):
for el in self.viewer_list:
Expand Down Expand Up @@ -367,6 +369,9 @@ def _screenshot():
return _screenshot

def image_read(self):
if self.settings.image_path is None:
self.setWindowTitle(f"{self.title_base}")
return
folder_name, file_name = os.path.split(self.settings.image_path)
self.setWindowTitle(f"{self.title_base}: {os.path.join(os.path.basename(folder_name), file_name)}")
self.statusBar().showMessage(self.settings.image_path)
Expand Down
5 changes: 3 additions & 2 deletions package/PartSegCore/image_transforming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from PartSegCore.algorithm_describe_base import Register
from PartSegCore.image_transforming.image_projection import ImageProjection
from PartSegCore.image_transforming.interpolate_image import InterpolateImage
from PartSegCore.image_transforming.swap_time_stack import SwapTimeStack
from PartSegCore.image_transforming.transform_base import TransformBase

image_transform_dict = Register(InterpolateImage, SwapTimeStack)
image_transform_dict = Register(InterpolateImage, SwapTimeStack, ImageProjection)

__all__ = ("image_transform_dict", "TransformBase")
__all__ = ("image_transform_dict", "InterpolateImage", "TransformBase", "ImageProjection")
75 changes: 75 additions & 0 deletions package/PartSegCore/image_transforming/image_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from enum import Enum
from typing import Callable, List, Optional, Tuple

import numpy as np
from pydantic import Field

from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegCore.utils import BaseModel
from PartSegImage import Image


class ProjectionType(Enum):
MAX = "max"
MIN = "min"
MEAN = "mean"
SUM = "sum"


class ImageProjectionParams(BaseModel):
projection_type: ProjectionType = Field(ProjectionType.MAX, suffix="Mask and ROI projection will allways use max")
keep_mask = False
keep_roi = False


def _calc_target_shape(image: Image):
new_shape = list(image.shape)
new_shape[image.array_axis_order.index("Z")] = 1
return tuple(new_shape)


class ImageProjection(TransformBase):
__argument_class__ = ImageProjectionParams

@classmethod
def transform(
cls,
image: Image,
roi_info: ROIInfo,
arguments: ImageProjectionParams, # type: ignore[override]
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
project_operator = getattr(np, arguments.projection_type.value)
axis = image.array_axis_order.index("Z")
target_shape = _calc_target_shape(image)
spacing = list(image.spacing)
spacing.pop(axis - 1 if image.time_pos < image.stack_pos else axis)
new_channels = [
project_operator(image.get_channel(i), axis=axis).reshape(target_shape) for i in range(image.channels)
]
new_mask = None
if arguments.keep_mask and image.mask is not None:
new_mask = np.max(image.mask, axis=axis).reshape(target_shape)

roi = None
if arguments.keep_roi and roi_info.roi is not None:
roi = ROIInfo(np.max(image.fit_array_to_image(roi_info.roi), axis=axis).reshape(target_shape))
return (
image.__class__(
data=new_channels, image_spacing=tuple(spacing), channel_names=image.channel_names, mask=new_mask
),
roi,
)

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
return cls.__argument_class__

@classmethod
def calculate_initial(cls, image: Image):
return cls.get_default_values()

@classmethod
def get_name(cls) -> str:
return "Image Projection"
20 changes: 13 additions & 7 deletions package/PartSegCore/image_transforming/interpolate_image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple, Union

from scipy.ndimage import zoom

from PartSegCore.algorithm_describe_base import AlgorithmProperty
from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


Expand All @@ -13,9 +14,10 @@ def get_fields(cls):
return ["It can be very slow.", AlgorithmProperty("scale", "Scale", 1.0)]

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
return ["it can be very slow"] + [
AlgorithmProperty(f"scale_{i.lower()}", f"Scale {i}", 1.0) for i in reversed(component_list)
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
return [
"it can be very slow",
*[AlgorithmProperty(f"scale_{i.lower()}", f"Scale {i}", 1.0) for i in reversed(component_list)],
]

@classmethod
Expand All @@ -24,8 +26,12 @@ def get_name(cls):

@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: Optional[Callable[[str, int], None]] = None
) -> Image:
cls,
image: Image,
roi_info: Optional[ROIInfo],
arguments: dict,
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
keys = [x for x in arguments if x.startswith("scale")]
keys_order = Image.axis_order.lower()
scale_factor = [1.0] * len(keys_order)
Expand All @@ -46,7 +52,7 @@ def transform(
mask = zoom(image.mask, scale_factor[:-1], mode="mirror")
else:
mask = None
return image.substitute(data=array, image_spacing=spacing, mask=mask)
return image.substitute(data=array, image_spacing=spacing, mask=mask), None

@classmethod
def calculate_initial(cls, image: Image):
Expand Down
11 changes: 8 additions & 3 deletions package/PartSegCore/image_transforming/swap_time_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

from PartSegCore.algorithm_describe_base import AlgorithmProperty
from PartSegCore.image_transforming.transform_base import TransformBase
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


class SwapTimeStack(TransformBase):
@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: typing.Optional[typing.Callable[[str, int], None]] = None
) -> Image:
return image.swap_time_and_stack()
cls,
image: Image,
roi_info: ROIInfo,
arguments: dict,
callback_function: typing.Optional[typing.Callable[[str, int], None]] = None,
) -> typing.Tuple[Image, typing.Optional[ROIInfo]]:
return image.swap_time_and_stack(), None

@classmethod
def get_fields_per_dimension(cls, component_list: typing.List[str]):
Expand Down
15 changes: 10 additions & 5 deletions package/PartSegCore/image_transforming/transform_base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from abc import ABC
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple, Union

from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase
from PartSegCore.algorithm_describe_base import AlgorithmDescribeBase, AlgorithmProperty
from PartSegCore.roi_info import ROIInfo
from PartSegImage import Image


class TransformBase(AlgorithmDescribeBase, ABC):
@classmethod
def transform(
cls, image: Image, arguments: dict, callback_function: Optional[Callable[[str, int], None]] = None
) -> Image:
cls,
image: Image,
roi_info: ROIInfo,
arguments: dict,
callback_function: Optional[Callable[[str, int], None]] = None,
) -> Tuple[Image, Optional[ROIInfo]]:
raise NotImplementedError

@classmethod
def get_fields_per_dimension(cls, component_list: List[str]):
def get_fields_per_dimension(cls, component_list: List[str]) -> List[Union[str, AlgorithmProperty]]:
raise NotImplementedError

@classmethod
Expand Down
27 changes: 20 additions & 7 deletions package/tests/test_PartSeg/roi_analysis/test_main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import pytest
from qtpy.QtCore import Qt

from PartSeg._roi_analysis import main_window as analysis_main_window
from PartSeg._roi_analysis.main_window import ChannelProperty, MainWindow, Options
from PartSegCore.analysis import ProjectTuple
from PartSegCore.roi_info import ROIInfo
from PartSegCore.segmentation import ROIExtractionResult


class TestAnalysisMainWindow:
# @pytest.mark.skipif((platform.system() == "Linux") and CI_BUILD, reason="debug test fail")
@pytest.mark.pyside_skip()
def test_opening(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
main_window.main_menu.batch_processing_btn.click()
main_window.main_menu.advanced_btn.click()
Expand All @@ -22,29 +24,40 @@ def test_opening(self, qtbot, tmpdir):

@pytest.mark.pyside_skip()
def test_change_theme(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
assert main_window.raw_image.viewer.theme == "light"
main_window.settings.theme_name = "dark"
assert main_window.raw_image.viewer.theme == "dark"

@pytest.mark.pyside_skip()
def test_scale_bar(self, qtbot, tmpdir):
main_window = analysis_main_window.MainWindow(tmpdir, initial_image=False)
main_window = MainWindow(tmpdir, initial_image=False)
qtbot.addWidget(main_window)
main_window._scale_bar_warning = False
assert not main_window.result_image.viewer.scale_bar.visible
main_window._toggle_scale_bar()
assert main_window.result_image.viewer.scale_bar.visible

def test_get_project_info(self, image, tmp_path):
res = MainWindow.get_project_info(str(tmp_path / "test.tiff"), image)
assert isinstance(res.roi_info, ROIInfo)
assert isinstance(res, ProjectTuple)

roi = np.zeros(image.shape, dtype=np.uint8)
roi[:, 2:-2] = 1
res = MainWindow.get_project_info(str(tmp_path / "test.tiff"), image, ROIInfo(roi))
assert isinstance(res.roi_info, ROIInfo)
assert set(res.roi_info.bound_info) == {1}


@pytest.fixture()
def analysis_options(qtbot, part_settings):
ch_property = analysis_main_window.ChannelProperty(part_settings, "test")
ch_property = ChannelProperty(part_settings, "test")
qtbot.addWidget(ch_property)
left_image = MagicMock()
synchronize = MagicMock()
options = analysis_main_window.Options(part_settings, ch_property, left_image, synchronize)
options = Options(part_settings, ch_property, left_image, synchronize)
qtbot.addWidget(options)
qtbot.addWidget(options.compare_btn)
return options
Expand Down Expand Up @@ -143,7 +156,7 @@ def test_execution_done(self, analysis_options, part_settings, monkeypatch):
info_mock = Mock()

monkeypatch.setattr(analysis_options, "sender", mock)
monkeypatch.setattr(analysis_main_window.QMessageBox, "information", info_mock)
monkeypatch.setattr("qtpy.QtWidgets.QMessageBox.information", info_mock)

res = ROIExtractionResult(
roi=np.zeros(part_settings.image.shape, dtype="uint8"),
Expand Down
Loading