Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Enable tiling non-PANDA WSI datasets (#621)
Browse files Browse the repository at this point in the history
* Add basic dataset and environment changes

* Add loading/preproc utils

* Back-up PANDA tiling scripts

* Refactor and generalise tiling scripts

* Remove Azure scripts

* Add test WSI file

* Add preprocessing tests

* Update changelog

* Add Linux condition for cuCIM in environment.yml

* Use PANDA instead of TCGA-PRAD in test

* Leave TcgaPradDataset as an example

* Fix skipped InnerEye dataset tests

* Create and test mock slides dataset

* Remove Tests/ML/datasets from pytest discovery
  • Loading branch information
dccastro authored Dec 16, 2021
1 parent 276e0f5 commit 6a4d334
Show file tree
Hide file tree
Showing 20 changed files with 917 additions and 266 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
*.dcm filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs that run in AzureML.
- ([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
- ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
- ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
- ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
Expand Down
112 changes: 110 additions & 2 deletions InnerEye/ML/Histopathology/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
# ------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset

from InnerEye.ML.Histopathology.utils.naming import SlideKey


class TilesDataset(Dataset):
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.
Expand Down Expand Up @@ -71,7 +73,7 @@ def __init__(self,
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)

columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN, self.LABEL_COLUMN,
columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
for column in columns:
if column is not None and column not in dataset_df.columns:
Expand Down Expand Up @@ -110,3 +112,109 @@ def get_class_weights(self) -> torch.Tensor:
classes = np.unique(slide_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
return torch.as_tensor(class_weights)


class SlidesDataset(Dataset):
    """Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.

    The output dictionaries are indexed by `..utils.naming.SlideKey`.

    :param SLIDE_ID_COLUMN: CSV column name for slide ID.
    :param IMAGE_COLUMN: CSV column name for relative path to image file.
    :param LABEL_COLUMN: CSV column name for slide label.
    :param MASK_COLUMN: CSV column name for relative path to mask file (optional).
    :param SPLIT_COLUMN: CSV column name for train/test split (optional).
    :param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
    :param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
    :param METADATA_COLUMNS: CSV column names whose values are copied into the
        `SlideKey.METADATA` dictionary of every sample.
    :param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
    :param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
    """
    SLIDE_ID_COLUMN: str = 'slide_id'
    IMAGE_COLUMN: str = 'image'
    LABEL_COLUMN: str = 'label'
    MASK_COLUMN: Optional[str] = None
    SPLIT_COLUMN: Optional[str] = None

    TRAIN_SPLIT_LABEL: str = 'train'
    TEST_SPLIT_LABEL: str = 'test'

    METADATA_COLUMNS: Tuple[str, ...] = ()

    DEFAULT_CSV_FILENAME: str = "dataset.csv"

    N_CLASSES: int = 1  # binary classification by default

    def __init__(self,
                 root: Union[str, Path],
                 dataset_csv: Optional[Union[str, Path]] = None,
                 dataset_df: Optional[pd.DataFrame] = None,
                 train: Optional[bool] = None,
                 validate_columns: bool = True) -> None:
        """
        :param root: Root directory of the dataset.
        :param dataset_csv: Full path to a dataset CSV file, containing at least
            `SLIDE_ID_COLUMN` and the columns checked by `validate_columns()` (unless a subclass
            derives them after calling this constructor). If omitted, the CSV will be read
            from `"{root}/{DEFAULT_CSV_FILENAME}"`.
        :param dataset_df: A potentially pre-processed dataframe in the same format as would be read
            from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
        :param train: If `True`, loads only the training split (resp. `False` for test split). By
            default (`None`), loads the entire dataset as-is.
        :param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
        :raises ValueError: If `train` is given but the class defines no `SPLIT_COLUMN`, or if
            column validation is enabled and an expected column is missing.
        """
        if self.SPLIT_COLUMN is None and train is not None:
            raise ValueError("Train/test split was specified but dataset has no split column")

        self.root_dir = Path(root)

        if dataset_df is not None:
            self.dataset_csv = None
        else:
            self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
            dataset_df = pd.read_csv(self.dataset_csv)

        # `set_index` returns a new frame, so a caller-provided `dataset_df` is left untouched.
        dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
        if train is None:
            self.dataset_df = dataset_df
        else:
            split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
            # Copy the filtered rows so that subclasses can safely add derived columns
            # to `self.dataset_df` without writing into a slice of the full frame.
            self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split].copy()

        if validate_columns:
            self.validate_columns()

    def validate_columns(self) -> None:
        """Check that loaded dataframe contains expected columns, raises `ValueError` otherwise.

        If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
        call `validate_columns()` after creating derived columns, for example.

        :raises ValueError: If any non-`None` expected column is absent from `self.dataset_df`.
        """
        columns = [self.IMAGE_COLUMN, self.LABEL_COLUMN, self.MASK_COLUMN,
                   self.SPLIT_COLUMN] + list(self.METADATA_COLUMNS)
        for column in columns:
            if column is not None and column not in self.dataset_df.columns:
                raise ValueError(f"Expected column '{column}' not found in the dataframe")

    def __len__(self) -> int:
        """Return the number of slides (rows) in the loaded dataframe."""
        return self.dataset_df.shape[0]

    def __getitem__(self, index: int) -> Dict[SlideKey, Any]:
        """Build the sample dictionary for the slide at positional `index`.

        :param index: Row position in `self.dataset_df`.
        :return: Dictionary keyed by `SlideKey` with the slide ID, image path (also replicated
            under `IMAGE_PATH`), mask path (only if `MASK_COLUMN` is set), label and metadata.
        """
        slide_id = self.dataset_df.index[index]
        # Positional lookup always yields a single row, whereas `.loc[slide_id]` would return
        # a whole DataFrame if the same slide ID occurred more than once in the index.
        slide_row = self.dataset_df.iloc[index]
        sample = {SlideKey.SLIDE_ID: slide_id}

        rel_image_path = slide_row[self.IMAGE_COLUMN]
        sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path)
        # we're replicating this column because we want to propagate the path to the batch
        sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE]

        # Same `is not None` criterion as `validate_columns()` and `has_mask()`.
        if self.MASK_COLUMN is not None:
            rel_mask_path = slide_row[self.MASK_COLUMN]
            sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path)
            sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK]

        sample[SlideKey.LABEL] = slide_row[self.LABEL_COLUMN]
        sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS}
        return sample

    @classmethod
    def has_mask(cls) -> bool:
        """Return whether this dataset class declares a mask column."""
        return cls.MASK_COLUMN is not None
2 changes: 2 additions & 0 deletions InnerEye/ML/Histopathology/datasets/default_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

# Dataset identifiers.
PANDA_DATASET_ID = "PANDA"
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"

# Default dataset directories: the base location (which already ends with a path
# separator) immediately followed by the dataset identifier.
DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
PANDA_DATASET_DIR = f"{DEFAULT_DATASET_LOCATION}{PANDA_DATASET_ID}"
PANDA_TILES_DATASET_DIR = f"{DEFAULT_DATASET_LOCATION}{PANDA_TILES_DATASET_ID}"
TCGA_CRCK_DATASET_DIR = f"{DEFAULT_DATASET_LOCATION}{TCGA_CRCK_DATASET_ID}"
TCGA_PRAD_DATASET_DIR = f"{DEFAULT_DATASET_LOCATION}{TCGA_PRAD_DATASET_ID}"
72 changes: 32 additions & 40 deletions InnerEye/ML/Histopathology/datasets/panda_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,42 @@
from typing import Any, Dict, Union, Optional

import pandas as pd
from cucim import CuImage
from health_ml.utils import box_utils
from monai.config import KeysCollection
from monai.data.image_reader import ImageReader, WSIReader
from monai.transforms import MapTransform
from openslide import OpenSlide
from torch.utils.data import Dataset

from health_ml.utils import box_utils
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset


class PandaDataset(Dataset):
class PandaDataset(SlidesDataset):
"""Dataset class for loading files from the PANDA challenge dataset.
Iterating over this dataset returns a dictionary containing the `'image_id'`, paths to the `'image'`
and `'mask'` files, and the remaining meta-data from the original dataset (`'data_provider'`,
`'isup_grade'`, and `'gleason_score'`).
Iterating over this dataset returns a dictionary following the `SlideKey` schema plus meta-data
from the original dataset (`'data_provider'`, `'isup_grade'`, and `'gleason_score'`).
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
"""
def __init__(self, root_dir: Union[str, Path], n_slides: Optional[int] = None,
frac_slides: Optional[float] = None) -> None:
super().__init__()
self.root_dir = Path(root_dir)
self.train_df = pd.read_csv(self.root_dir / "train.csv", index_col='image_id')
if n_slides or frac_slides:
self.train_df = self.train_df.sample(n=n_slides, frac=frac_slides, replace=False,
random_state=1234)

def __len__(self) -> int:
return self.train_df.shape[0]

def _get_image_path(self, image_id: str) -> Path:
return self.root_dir / "train_images" / f"{image_id}.tiff"

def _get_mask_path(self, image_id: str) -> Path:
return self.root_dir / "train_label_masks" / f"{image_id}_mask.tiff"

def __getitem__(self, index: int) -> Dict:
image_id = self.train_df.index[index]
return {
'image_id': image_id,
'image': str(self._get_image_path(image_id).absolute()),
'mask': str(self._get_mask_path(image_id).absolute()),
**self.train_df.loc[image_id].to_dict()
}
SLIDE_ID_COLUMN = 'image_id'
IMAGE_COLUMN = 'image'
MASK_COLUMN = 'mask'
LABEL_COLUMN = 'isup_grade'

METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score')

DEFAULT_CSV_FILENAME = "train.csv"

def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None) -> None:
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
# PANDA CSV does not come with paths for image and mask files
slide_ids = self.dataset_df.index
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
self.validate_columns()


# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name
Expand Down Expand Up @@ -96,25 +88,25 @@ def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str =
self.margin = margin
self.kwargs = kwargs

def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box:
def _get_bounding_box(self, mask_obj: CuImage) -> box_utils.Box:
# Estimate bounding box at the lowest resolution (i.e. highest level)
highest_level = mask_obj.level_count - 1
scale = mask_obj.level_downsamples[highest_level]
highest_level = mask_obj.resolutions['level_count'] - 1
scale = mask_obj.resolutions['level_downsamples'][highest_level]
mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image

foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel
bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin)
return bbox

def __call__(self, data: Dict) -> Dict:
mask_obj: OpenSlide = self.reader.read(data[self.mask_key])
image_obj: OpenSlide = self.reader.read(data[self.image_key])
mask_obj: CuImage = self.reader.read(data[self.mask_key])
image_obj: CuImage = self.reader.read(data[self.image_key])

level0_bbox = self._get_bounding_box(mask_obj)

# OpenSlide takes absolute location coordinates in the level 0 reference frame,
# cuCIM/OpenSlide take absolute location coordinates in the level 0 reference frame,
# but relative region size in pixels at the chosen level
scale = mask_obj.level_downsamples[self.level]
scale = mask_obj.resolutions['level_downsamples'][self.level]
scaled_bbox = level0_bbox / scale
get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y),
size=(scaled_bbox.w, scaled_bbox.h),
Expand Down
42 changes: 11 additions & 31 deletions InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# ------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Optional, Union

import pandas as pd
from torch.utils.data import Dataset

from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset

class TcgaPradDataset(Dataset):

class TcgaPradDataset(SlidesDataset):
"""Dataset class for loading TCGA-PRAD slides.
Iterating over this dataset returns a dictionary containing:
Expand All @@ -19,44 +20,23 @@ class TcgaPradDataset(Dataset):
- `'image_path'` (str): absolute slide image path
- `'label'` (int, 0 or 1): label for predicting positive or negative
"""
SLIDE_ID_COLUMN: str = 'slide_id'
CASE_ID_COLUMN: str = 'case_id'
IMAGE_COLUMN: str = 'image_path'
LABEL_COLUMN: str = 'label'

DEFAULT_CSV_FILENAME: str = "dataset.csv"

def __init__(self, root_dir: Union[str, Path],
def __init__(self, root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,) -> None:
dataset_df: Optional[pd.DataFrame] = None) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from
`"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
"""
self.root_dir = Path(root_dir)

if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)

dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
dataset_df[self.LABEL_COLUMN] = (dataset_df['label1_mutation']
| dataset_df['label2_mutation']).astype(int)
self.dataset_df = dataset_df

def __len__(self) -> int:
return self.dataset_df.shape[0]

def __getitem__(self, index: int) -> Dict[str, Any]:
slide_id = self.dataset_df.index[index]
sample = {
self.SLIDE_ID_COLUMN: slide_id,
**self.dataset_df.loc[slide_id].to_dict()
}
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
return sample
super().__init__(root, dataset_csv, dataset_df, validate_columns=False)
# Example of how to define a custom label column from existing columns:
self.dataset_df[self.LABEL_COLUMN] = (self.dataset_df['label1']
| self.dataset_df['label2']).astype(int)
self.validate_columns()
Loading

0 comments on commit 6a4d334

Please sign in to comment.