diff --git a/package/PartSegCore/napari_plugins/loader.py b/package/PartSegCore/napari_plugins/loader.py index 35e1961cb..3a47ab9b9 100644 --- a/package/PartSegCore/napari_plugins/loader.py +++ b/package/PartSegCore/napari_plugins/loader.py @@ -51,13 +51,13 @@ def add_color(image: Image, idx: int) -> dict: # noqa: ARG001 return {} -def _image_to_layers(project_info, scale, translate): +def _image_to_layers(project_info, scale, translate, shear): res_layers = [] if project_info.image.name == "ROI" and project_info.image.channels == 1: res_layers.append( ( project_info.image.get_channel(0), - {"scale": scale, "name": project_info.image.channel_names[0], "translate": translate}, + {"scale": scale, "name": project_info.image.channel_names[0], "translate": translate, "shear": shear}, "labels", ) ) @@ -71,6 +71,7 @@ def _image_to_layers(project_info, scale, translate): "blending": "additive", "translate": translate, "metadata": project_info.image.metadata, + "shear": shear, **add_color(project_info.image, i), }, "image", @@ -85,14 +86,15 @@ def project_to_layers(project_info: typing.Union[ProjectTuple, MaskProjectTuple] res_layers = [] if project_info.image is not None and not isinstance(project_info.image, str): scale = project_info.image.normalized_scaling() + shear = project_info.image.shear translate = project_info.image.shift translate = (0,) * (len(project_info.image.axis_order.replace("C", "")) - len(translate)) + translate - res_layers.extend(_image_to_layers(project_info, scale, translate)) + res_layers.extend(_image_to_layers(project_info, scale, translate, shear)) if project_info.roi_info.roi is not None: res_layers.append( ( project_info.image.fit_array_to_image(project_info.roi_info.roi), - {"scale": scale, "name": "ROI", "translate": translate}, + {"scale": scale, "name": "ROI", "translate": translate, "shear": shear}, "labels", ) ) @@ -105,6 +107,7 @@ def project_to_layers(project_info: typing.Union[ProjectTuple, MaskProjectTuple] "name": name, "translate": translate, "visible": False, + "shear": shear, }, "labels", ) @@ -115,7 +118,7 @@ def project_to_layers(project_info: typing.Union[ProjectTuple, MaskProjectTuple] res_layers.append( ( project_info.image.fit_array_to_image(project_info.mask), - {"scale": scale, "name": "Mask", "translate": translate}, + {"scale": scale, "name": "Mask", "translate": translate, "shear": shear}, "labels", ) ) diff --git a/package/PartSegImage/image.py b/package/PartSegImage/image.py index 73eca3b11..e5085da99 100644 --- a/package/PartSegImage/image.py +++ b/package/PartSegImage/image.py @@ -224,6 +224,7 @@ def __init__( channel_info: list[ChannelInfo | ChannelInfoFull] | None = None, axes_order: str | None = None, shift: Spacing | None = None, + shear: np.ndarray | None = None, name: str = "", metadata_dict: dict | None = None, ): @@ -243,6 +244,7 @@ def __init__( self._image_spacing = tuple(el if el > 0 else 10**-6 for el in self._image_spacing) self._shift = tuple(shift) if shift is not None else (0,) * len(self._image_spacing) + self._shear = shear self.name = name self.file_path = file_path @@ -369,6 +371,10 @@ def _merge_channel_names(base_channel_names: list[str], new_channel_names: list[ base_channel_names.append(new_name) return base_channel_names + @property + def shear(self) -> np.ndarray | None: + return self._shear + @property def channel_info(self) -> list[ChannelInfoFull]: return [copy(x) for x in self._channel_info] diff --git a/package/PartSegImage/image_reader.py b/package/PartSegImage/image_reader.py index f97c8bcb2..b2fb53c00 100644 --- a/package/PartSegImage/image_reader.py +++ b/package/PartSegImage/image_reader.py @@ -1,6 +1,7 @@ import inspect import os.path import typing +import warnings from abc import abstractmethod from contextlib import suppress from importlib.metadata import version @@ -389,8 +390,36 @@ def read(self, image_path: typing.Union[str, BytesIO, Path], mask_path=None, ext axes_order=self.return_order(), metadata_dict=metadata, channel_info=self._get_channel_info(), + shear=self._read_shear(metadata), ) + def _read_shear(self, metadata: dict): + skew = self._read_skew(metadata) + shear = np.diag([1.0] * len(skew)) + for i, val in enumerate(skew): + if val == 0: + continue + shear[i, i + 1] = np.tan(np.radians(val)) + return shear + + @staticmethod + def _read_skew(metadata: dict): + dimensions = metadata["ImageDocument"]["Metadata"]["Information"]["Image"]["Dimensions"] + res = [0.0] * 4 + for i, dim in enumerate("TZYX"): + if dim not in dimensions: + continue + if f"{dim}AxisShear" not in dimensions[dim]: + continue + + shear_value = dimensions[dim][f"{dim}AxisShear"] + if not shear_value.startswith("Skew"): + warnings.warn(f"Unknown shear value {shear_value}", stacklevel=1) + continue + res[i] = float(shear_value[4:]) + + return res + @classmethod def update_array_shape(cls, array: np.ndarray, axes: str): if "B" in axes: