Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡ Faster Filtering in Patch Dataset #571

Merged
merged 10 commits into from
Apr 6, 2023
68 changes: 17 additions & 51 deletions tests/test_patch_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_get_patch_extractor(source_image, patch_extr_csv):
)

assert isinstance(points, patchextraction.PointsPatchExtractor)
assert len(points) == 1860

sliding_window = patchextraction.get_patch_extractor(
input_img=input_img,
Expand Down Expand Up @@ -424,97 +425,62 @@ def test_filter_coordinates():
mask = np.zeros([9, 6])
mask[0:4, 3:8] = 1 # will flag first 2
mask_reader = VirtualWSIReader(mask)

# Tests for (original) filter_coordinates method
# functionality test
flag_list = PatchExtractor.filter_coordinates(
mask_reader, bbox_list, resolution=1.0, units="baseline"
)
assert np.sum(flag_list - np.array([1, 1, 0, 0, 0, 0])) == 0

# Test for bad mask input
with pytest.raises(
ValueError, match="`mask_reader` should be wsireader.VirtualWSIReader."
):
PatchExtractor.filter_coordinates(
mask, bbox_list, resolution=1.0, units="baseline"
)

# Test for bad bbox coordinate list in the input
with pytest.raises(ValueError, match=r".*should be ndarray of integer type.*"):
PatchExtractor.filter_coordinates(
mask_reader, bbox_list.tolist(), resolution=1.0, units="baseline"
)

# Test for incomplete coordinate list
with pytest.raises(ValueError, match=r".*`coordinates_list` of shape.*"):
PatchExtractor.filter_coordinates(
mask_reader, bbox_list[:, :2], resolution=1.0, units="baseline"
)
slide_shape = [6, 9] # slide shape (w, h) at requested resolution

###############################################
# Tests for filter_coordinates_fast (new) method
# Tests for filter_coordinates (new) method
_info = mask_reader.info
_info.mpp = 1.0
mask_reader._m_info = _info

# functionality test
flag_list = PatchExtractor.filter_coordinates_fast(
flag_list = PatchExtractor.filter_coordinates(
mask_reader,
bbox_list,
coordinate_resolution=1.0,
coordinate_units="mpp",
mask_resolution=1,
slide_shape,
)
assert np.sum(flag_list - np.array([1, 1, 0, 0, 0, 0])) == 0
flag_list = PatchExtractor.filter_coordinates_fast(
mask_reader, bbox_list, coordinate_resolution=(1.0, 1.0), coordinate_units="mpp"
)
flag_list = PatchExtractor.filter_coordinates(mask_reader, bbox_list, slide_shape)

# Test for bad mask input
with pytest.raises(
ValueError, match="`mask_reader` should be wsireader.VirtualWSIReader."
):
PatchExtractor.filter_coordinates_fast(
PatchExtractor.filter_coordinates(
mask,
bbox_list,
coordinate_resolution=1.0,
coordinate_units="mpp",
slide_shape,
)

# Test for bad bbox coordinate list in the input
with pytest.raises(ValueError, match=r".*should be ndarray of integer type.*"):
PatchExtractor.filter_coordinates_fast(
PatchExtractor.filter_coordinates(
mask_reader,
bbox_list.tolist(),
coordinate_resolution=1,
coordinate_units="mpp",
slide_shape,
)

# Test for incomplete coordinate list
with pytest.raises(ValueError, match=r".*`coordinates_list` must be of shape.*"):
PatchExtractor.filter_coordinates_fast(
PatchExtractor.filter_coordinates(
mask_reader,
bbox_list[:, :2],
coordinate_resolution=1,
coordinate_units="mpp",
slide_shape,
)

# Test for out of range min_mask_ratio
with pytest.raises(ValueError, match="`min_mask_ratio` must be between 0 and 1."):
PatchExtractor.filter_coordinates_fast(
PatchExtractor.filter_coordinates(
mask_reader,
bbox_list,
coordinate_resolution=1.0,
coordinate_units="mpp",
slide_shape,
min_mask_ratio=-0.5,
)
with pytest.raises(ValueError, match="`min_mask_ratio` must be between 0 and 1."):
PatchExtractor.filter_coordinates_fast(
PatchExtractor.filter_coordinates(
mask_reader,
bbox_list,
coordinate_resolution=1.0,
coordinate_units="mpp",
slide_shape,
min_mask_ratio=1.1,
)

Expand All @@ -529,7 +495,7 @@ def test_mask_based_patch_extractor_ndpi(sample_ndpi):

# Generating a test mask to read patches from
mask_dim = (int(slide_dimensions[0] / 10), int(slide_dimensions[1] / 10))
wsi_mask = np.zeros(mask_dim, dtype=np.uint8)
wsi_mask = np.zeros(mask_dim[::-1], dtype=np.uint8) # reverse as dims are (w, h)
# masking two columns to extract patches from
wsi_mask[:, :2] = 255

Expand Down
36 changes: 22 additions & 14 deletions tiatoolbox/models/dataset/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,8 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):
units:
See (:class:`.WSIReader`) for details.
preproc_func:
Preprocessing function used to transform the input data. If
supplied, then torch.Compose will be used on the input
preprocs. preprocs is a list of torchvision transforms for
preprocessing the image. The transforms will be applied in
the order that they are given in the list. For more
information, visit the following link:
https://pytorch.org/vision/stable/transforms.html.
Preprocessing function used to transform the input data. It will
be called on each patch before returning it.

"""

Expand All @@ -162,6 +157,8 @@ def __init__(
resolution=None,
units=None,
auto_get_mask=True,
min_mask_ratio=0,
preproc_func=None,
):
"""Create a WSI-level patch dataset.

Expand All @@ -187,12 +184,22 @@ def __init__(
`units`. Expected to be positive and of (height, width).
Note, this is not at level 0.
resolution:
Check (:class:`.WSIReader`) for details. When
`mode='tile'`, value is fixed to be `resolution=1.0` and
`units='baseline'` units: check (:class:`.WSIReader`) for
details.
Check (:class:`.WSIReader`) for details. When
`mode='tile'`, value is fixed to be `resolution=1.0` and
`units='baseline'`. For `units`, also check
(:class:`.WSIReader`) for details.
units:
Units in which `resolution` is defined.
auto_get_mask:
If `True`, then automatically get a simple threshold mask using
the WSIReader.tissue_mask() function.
min_mask_ratio:
Only patches with positive area percentage above this value are
included. Defaults to 0.
preproc_func:
Preprocessing function used to transform the input data.
Preprocessing function used to transform the input data. If
supplied, the function will be called on each patch before
returning it.

Examples:
>>> # A user defined preproc func and expected behavior
Expand Down Expand Up @@ -233,6 +240,7 @@ def __init__(
):
raise ValueError(f"Invalid `stride_shape` value {stride_shape}.")

self.preproc_func = preproc_func
img_path = pathlib.Path(img_path)
if mode == "wsi":
self.reader = WSIReader.open(img_path)
Expand Down Expand Up @@ -299,8 +307,8 @@ def __init__(
selected = PatchExtractor.filter_coordinates(
mask_reader, # must be at the same resolution
self.inputs, # must already be at requested resolution
resolution=resolution,
units=units,
wsi_shape=wsi_shape,
min_mask_ratio=min_mask_ratio,
)
self.inputs = self.inputs[selected]

Expand Down
Loading