Skip to content

Commit

Permalink
🚸 Handle Multidimensional Features in WSIReader (#742)
Browse files Browse the repository at this point in the history
- Adds support for multichannel inputs e.g., image masks, feature maps etc. Fixes #609 and #610 

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 29, 2024
1 parent 027ffdd commit 9c51582
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 6 deletions.
33 changes: 32 additions & 1 deletion tests/test_wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
}
RNG = np.random.default_rng() # Numpy Random Generator


# -------------------------------------------------------------------------------------
# Utility Test Functions
# -------------------------------------------------------------------------------------
Expand Down Expand Up @@ -2110,7 +2111,6 @@ def test_store_reader_alpha(remote_sample: Callable) -> None:
wsi_reader.info,
base_wsi=wsi_reader,
)
store_reader.renderer.info["mpp"] = store_reader.info.as_dict()["mpp"]
wsi_thumb = wsi_reader.slide_thumbnail()
wsi_tile = wsi_reader.read_rect((500, 500), (1000, 1000))
store_thumb = store_reader.slide_thumbnail()
Expand Down Expand Up @@ -2759,3 +2759,34 @@ def test_file_path_does_not_exist() -> None:
def test_read_mpp(wsi: WSIReader) -> None:
"""Test that the mpp is read correctly."""
assert wsi.info.mpp == pytest.approx(0.25, 1)


def test_read_multi_channel(source_image: Path) -> None:
"""Test reading image with more than three channels.
Create a virtual WSI by concatenating the source_image.
"""
img_array = utils.misc.imread(Path(source_image))
new_img_array = np.concatenate((img_array, img_array), axis=-1)

new_img_size = new_img_array.shape[:2][::-1]
meta = wsireader.WSIMeta(slide_dimensions=new_img_size, axes="YXS", mpp=(0.5, 0.5))
wsi = wsireader.VirtualWSIReader(new_img_array, info=meta)

region = wsi.read_rect(
location=(0, 0),
size=(50, 100),
pad_mode="reflect",
units="mpp",
resolution=0.25,
)
target = cv2.resize(
new_img_array[:50, :25, :],
(50, 100),
interpolation=cv2.INTER_CUBIC,
)

assert region.shape == (100, 50, (new_img_array.shape[-1]))
assert np.abs(np.median(region.astype(int) - target.astype(int))) == 0
assert np.abs(np.mean(region.astype(int) - target.astype(int))) < 0.2
33 changes: 28 additions & 5 deletions tiatoolbox/wsicore/wsireader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2792,8 +2792,14 @@ class VirtualWSIReader(WSIReader):
:func:`~tiatoolbox.utils.image.sub_pixel_read`.
Attributes:
img (:class:`numpy.ndarray`)
mode (str)
img (:class:`numpy.ndarray`):
Input image as :class:`numpy.ndarray`.
mode (str):
Mode of the input image. Default is 'rgb'. Allowed values
are: rgb, bool, feature. "rgb" mode supports bright-field color images.
"bool" mode supports binary masks,
interpolation in this case will be "nearest" instead of "bicubic".
"feature" mode allows multichannel features.
Args:
input_img (str, :obj:`Path`, :class:`numpy.ndarray`):
Expand All @@ -2802,7 +2808,10 @@ class VirtualWSIReader(WSIReader):
Metadata for the virtual wsi.
mode (str):
Mode of the input image. Default is 'rgb'. Allowed values
are: rgb, bool.
are: rgb, bool, feature. "rgb" mode supports bright-field color images.
"bool" mode supports binary masks,
interpolation in this case will be "nearest" instead of "bicubic".
"feature" mode allows multichannel features.
"""

Expand All @@ -2820,15 +2829,26 @@ def __init__(
mpp=mpp,
power=power,
)
if mode.lower() not in ["rgb", "bool"]:
if mode.lower() not in ["rgb", "bool", "feature"]:
msg = "Invalid mode."
raise ValueError(msg)
self.mode = mode.lower()

if isinstance(input_img, np.ndarray):
self.img = input_img
else:
self.img = utils.imread(self.input_path)

if mode != "bool" and (
self.img.ndim == 2 or self.img.shape[2] not in [3, 4] # noqa: PLR2004
):
logger.warning(
"The image mode is set to 'feature' as the input"
" dimensions do not match with binary mask or RGB/RGBA.",
)
mode = "feature"

self.mode = mode.lower()

if info is not None:
self._m_info = info

Expand Down Expand Up @@ -3278,6 +3298,9 @@ class docstrings for more information.
if interpolation in [None, "none"]:
interpolation = None

if interpolation == "optimise" and self.mode == "bool":
interpolation = "nearest"

im_region = utils.image.sub_pixel_read(
self.img,
bounds_at_read,
Expand Down

0 comments on commit 9c51582

Please sign in to comment.