Skip to content

Commit

Permalink
add support for slicing imagelike objects
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzocerrone committed Sep 26, 2024
1 parent 13d6d91 commit 95dd94d
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 30 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
"pydantic",
"requests",
"aiohttp",
"dask[array]"
"dask[complete]",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
Expand Down
219 changes: 215 additions & 4 deletions src/ngio/core/image_like_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Generic class to handle Image-like data in a OME-NGFF file."""

from pathlib import Path
from typing import Literal
from warnings import warn

import dask.array as da
import numpy as np
import zarr
from dask.delayed import Delayed
from dask.distributed import Lock

from ngio._common_types import ArrayLike
from ngio.core.dimensions import Dimensions
from ngio.core.roi import WorldCooROI
from ngio.io import StoreOrGroup, open_group_wrapper
from ngio.ngff_meta import (
Dataset,
Expand All @@ -15,7 +21,7 @@
SpaceUnits,
get_ngff_image_meta_handler,
)
from ngio.pipes import DataTransformPipe, NaiveSlicer
from ngio.pipes import DataTransformPipe, NaiveSlicer, RoiSlicer


class ImageLike:
Expand Down Expand Up @@ -50,6 +56,9 @@ def __init__(
meta_mode (str): The mode of the metadata handler.
cache (bool): Whether to cache the metadata.
"""
if not strict:
warn("Strict mode is not fully supported yet.", UserWarning, stacklevel=2)

if not isinstance(store, zarr.Group):
store = open_group_wrapper(store=store, mode="r+")

Expand All @@ -68,7 +77,15 @@ def __init__(
highest_resolution=highest_resolution,
strict=strict,
)

if pixel_size is not None:
pixel_size.virtual = True
self._virtual_pixel_size = pixel_size
else:
self._virtual_pixel_size = None

self._init_dataset(dataset)
self._dask_lock = None

def _init_dataset(self, dataset: Dataset):
"""Set the dataset of the image.
Expand Down Expand Up @@ -145,6 +162,9 @@ def space_axes_unit(self) -> SpaceUnits:
@property
def pixel_size(self) -> PixelSize:
"""Return the pixel resolution of the image."""
if self._virtual_pixel_size is not None:
return self._virtual_pixel_size

return self.dataset.pixel_size

# Method to get the data of the image
Expand All @@ -153,6 +173,11 @@ def array(self) -> zarr.Array:
"""Return the image data as a Zarr array."""
return self._array

@property
def dask_array(self) -> da.core.Array:
"""Return the image data as a Dask array."""
return da.from_zarr(self.array)

@property
def dimensions(self) -> Dimensions:
"""Return the dimensions of the image."""
Expand All @@ -168,6 +193,192 @@ def on_disk_shape(self) -> tuple[int, ...]:
"""Return the shape of the image."""
return self.dimensions.on_disk_shape

def get_data(self) -> np.ndarray:
"""Return the image data as a Zarr array."""
return self.array[...]
def init_lock(self, lock_id: str | None = None) -> None:
"""Set the lock for the Dask array."""
# Unique zarr array identifier
array_path = (
Path(self._group.store.path) / self._group.path / self._dataset.path
)
lock_id = f"Zarr_IO_Lock_{array_path}" if lock_id is None else lock_id
self._dask_lock = Lock(lock_id)

def _get_pipe(
self,
data_pipe: DataTransformPipe,
mode: Literal["numpy", "dask"] = "numpy",
) -> ArrayLike:
"""Return the data transform pipe."""
if mode == "numpy":
return data_pipe.get(data=self.array)
elif mode == "dask":
return data_pipe.get(data=self.dask_array)
if self._dask_lock is None:
return data_pipe.get(data=self.dask_array)

with self._dask_lock:
patch = data_pipe.get(data=self.dask_array)
return patch
else:
raise ValueError(f"Invalid mode {mode}")

def _set_pipe(
self,
data_pipe: DataTransformPipe,
patch: ArrayLike,
) -> None:
"""Set the data transform pipe."""
if isinstance(patch, np.ndarray):
data_pipe.set(data=self.array, patch=patch)

elif isinstance(patch, (da.core.Array, Delayed)): # noqa: UP038
if self._dask_lock is None:
return data_pipe.set(data=self.array, patch=patch)

array = self.array
with self._dask_lock:
data_pipe.set(data=array, patch=patch)
else:
raise ValueError(
f"Invalid patch type {type(patch)}. "
"Supported types are np.ndarray and da.core.Array"
)

def _build_roi_pipe(
self,
roi: WorldCooROI,
t: int | slice | None = None,
c: int | slice | None = None,
preserve_dimensions: bool = False,
) -> DataTransformPipe:
"""Build the data transform pipe for a region of interest (ROI)."""
roi_coo = roi.to_raster_coo(
pixel_size=self.dataset.pixel_size, dimensions=self.dimensions
)
slicer = RoiSlicer(
on_disk_axes_name=self.dataset.on_disk_axes_names,
axes_order=self.dataset.axes_order,
roi=roi_coo,
t=t,
c=c,
preserve_dimensions=preserve_dimensions,
)
return DataTransformPipe(slicer=slicer)

def _build_naive_pipe(
self,
x: int | slice | None = None,
y: int | slice | None = None,
z: int | slice | None = None,
t: int | slice | None = None,
c: int | slice | None = None,
preserve_dimensions: bool = False,
) -> DataTransformPipe:
"""Build the data transform pipe for a naive slice."""
slicer = NaiveSlicer(
on_disk_axes_name=self.dataset.on_disk_axes_names,
axes_order=self.dataset.axes_order,
x=x,
y=y,
z=z,
t=t,
c=c,
preserve_dimensions=preserve_dimensions,
)
return DataTransformPipe(slicer=slicer)

def get_data_from_roi(
self,
roi: WorldCooROI,
t: int | slice | None = None,
c: int | slice | None = None,
mode: Literal["numpy", "dask"] = "numpy",
preserve_dimensions: bool = False,
) -> ArrayLike | tuple[ArrayLike, DataTransformPipe]:
"""Return the image data from a region of interest (ROI).
Args:
roi (WorldCooROI): The region of interest.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
mode (str): The mode to return the data.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
data_pipe = self._build_roi_pipe(
roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions
)
return_array = self._get_pipe(data_pipe=data_pipe, mode=mode)
return return_array

def set_data_from_roi(
self,
roi: WorldCooROI,
patch: ArrayLike,
t: int | slice | None = None,
c: int | slice | None = None,
preserve_dimensions: bool = False,
) -> None:
"""Set the image data from a region of interest (ROI).
Args:
roi (WorldCooROI): The region of interest.
patch (ArrayLike): The patch to set.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
data_pipe = self._build_roi_pipe(
roi=roi, t=t, c=c, preserve_dimensions=preserve_dimensions
)
self._set_pipe(data_pipe=data_pipe, patch=patch)

def get_data(
self,
x: int | slice | None = None,
y: int | slice | None = None,
z: int | slice | None = None,
t: int | slice | None = None,
c: int | slice | None = None,
mode: Literal["numpy", "dask"] = "numpy",
preserve_dimensions: bool = False,
) -> ArrayLike:
"""Return the image data.
Args:
x (int | slice | None): The x index or slice.
y (int | slice | None): The y index or slice.
z (int | slice | None): The z index or slice.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
mode (str): The mode to return the data.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
data_pipe = self._build_naive_pipe(
x=x, y=y, z=z, t=t, c=c, preserve_dimensions=preserve_dimensions
)
return self._get_pipe(data_pipe=data_pipe, mode=mode)

def set_data(
self,
patch: ArrayLike,
x: int | slice | None = None,
y: int | slice | None = None,
z: int | slice | None = None,
t: int | slice | None = None,
c: int | slice | None = None,
preserve_dimensions: bool = False,
) -> None:
"""Set the image data in the zarr array.
Args:
patch (ArrayLike): The patch to set.
x (int | slice | None): The x index or slice.
y (int | slice | None): The y index or slice.
z (int | slice | None): The z index or slice.
t (int | slice | None): The time index or slice.
c (int | slice | None): The channel index or slice.
preserve_dimensions (bool): Whether to preserve the dimensions of the data.
"""
data_pipe = self._build_naive_pipe(
x=x, y=y, z=z, t=t, c=c, preserve_dimensions=preserve_dimensions
)
self._set_pipe(data_pipe=data_pipe, patch=patch)
5 changes: 4 additions & 1 deletion src/ngio/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def create_empty_ome_zarr_image(
store: StoreLike,
shape: list[int],
chunks: list[int] | None = None,
dtype: str = "uint16",
on_disk_axis: list[str] = ("t", "c", "z", "y", "x"),
pixel_sizes: PixelSize | None = None,
Expand Down Expand Up @@ -87,7 +88,9 @@ def create_empty_ome_zarr_image(
for dataset in image_meta.datasets:
path = dataset.path

group.create_array(name=path, fill_value=0, shape=shape, dtype=dtype)
group.create_array(
name=path, fill_value=0, shape=shape, dtype=dtype, chunks=chunks
)

# Todo redo this with when a proper build of pyramid id implemente
shape = [int(s / sc) for s, sc in zip(shape, scaling_factor, strict=True)]
1 change: 1 addition & 0 deletions src/ngio/ngff_meta/fractal_image_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class PixelSize(BaseModel):
y: float = Field(..., ge=0)
z: float = Field(1.0, ge=0)
unit: SpaceUnits = SpaceUnits.micrometer
virtual: bool = False

@classmethod
def from_list(cls, sizes: list[float], unit: SpaceUnits):
Expand Down
4 changes: 2 additions & 2 deletions src/ngio/pipes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A module to handle data transforms for image data."""

from ngio.pipes._slicer_transforms import NaiveSlicer
from ngio.pipes._slicer_transforms import NaiveSlicer, RoiSlicer
from ngio.pipes.data_pipe import DataTransformPipe

__all__ = ["DataTransformPipe", "NaiveSlicer"]
__all__ = ["DataTransformPipe", "NaiveSlicer", "RoiSlicer"]
Loading

0 comments on commit 95dd94d

Please sign in to comment.