Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Plotting heatmap and thumbnails for test PANDA slides #634

Merged
merged 25 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ jobs that run in AzureML.
- ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
- ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
- ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets
- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
Expand Down
49 changes: 40 additions & 9 deletions InnerEye/ML/Histopathology/models/deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, List
import torch
import matplotlib.pyplot as plt
import more_itertools as mi

from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, no_grad, optim, round
Expand All @@ -17,9 +19,13 @@
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist
from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist, plot_heatmap_overlay, plot_slide
from InnerEye.ML.Histopathology.utils.naming import ResultsKey

from monai.data.dataset import Dataset
from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
from InnerEye.ML.Histopathology.utils.naming import SlideKey


RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
Expand All @@ -46,7 +52,9 @@ def __init__(self,
weight_decay: float = 1e-4,
adam_betas: Tuple[float, float] = (0.9, 0.99),
verbose: bool = False,
) -> None:
slide_dataset: Dataset = Dataset(data=[]),
dccastro marked this conversation as resolved.
Show resolved Hide resolved
tile_size: int = 224,
level: Optional[int] = 1) -> None:
dccastro marked this conversation as resolved.
Show resolved Hide resolved
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction.
Expand All @@ -61,6 +69,9 @@ def __init__(self,
:param weight_decay: Weight decay parameter for L2 regularisation.
:param adam_betas: Beta parameters for Adam optimiser.
:param verbose: if True statements about memory usage are output at each step
:param slide_dataset: Slide dataset object, if available.
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
"""
super().__init__()

Expand All @@ -79,7 +90,13 @@ def __init__(self,
self.weight_decay = weight_decay
self.adam_betas = adam_betas

# Slide specific attributes
self.slide_dataset = slide_dataset
self.tile_size = tile_size
self.level = level

self.save_hyperparameters()

self.verbose = verbose

self.aggregation_fn, self.num_pooling = self.get_pooling()
Expand Down Expand Up @@ -288,29 +305,43 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore
print("Selecting tiles ...")
fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att'))
fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att'))
tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highes_pred', 'highest_att'))
tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att'))
tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att'))
report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]}

for key in report_cases.keys():
print(f"Plotting {key} ...")
print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...")
key_folder_path = outputs_fig_path / f'{key}'
Path(key_folder_path).mkdir(parents=True, exist_ok=True)
nslides = len(report_cases[key][0])
for i in range(nslides):
slide, score, paths, top_attn = report_cases[key][0][i]
fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4)
figpath = Path(key_folder_path, f'{slide}_top.png')
fig.savefig(figpath, bbox_inches='tight')
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png'))

slide, score, paths, bottom_attn = report_cases[key][1][i]
fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4)
harshita-s marked this conversation as resolved.
Show resolved Hide resolved
figpath = Path(key_folder_path, f'{slide}_bottom.png')
fig.savefig(figpath, bbox_inches='tight')
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png'))

if len(self.slide_dataset) > 0:
slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore
load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore
harshita-s marked this conversation as resolved.
Show resolved Hide resolved
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.LOCATION]

fig = plot_slide(slide_image=slide_image, scale=1.0)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
location_bbox=location_bbox, tile_size=self.tile_size, level=slide_dict['level'])
harshita-s marked this conversation as resolved.
Show resolved Hide resolved
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))

print("Plotting histogram ...")
fig = plot_scores_hist(results)
fig.savefig(outputs_fig_path / 'hist_scores.png', bbox_inches='tight')
self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')

@staticmethod
def save_figure(fig: plt.figure, figpath: Path) -> None:
fig.savefig(figpath, bbox_inches='tight')
harshita-s marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict:
Expand Down
31 changes: 31 additions & 0 deletions InnerEye/ML/Histopathology/utils/heatmap_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import List
import numpy as np


def location_selected_tiles(tile_coords: np.ndarray,
location_bbox: List[int],
level: int) -> np.ndarray:
""" Return the scaled and shifted tile co-ordinates for selected tiles in the slide.
:param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]) in original resolution.
:param location_bbox: Location of the bounding box on the slide in original resolution.
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available.
(e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled).
"""
level_dict = {0: 1, 1: 4, 2: 16}
factor = level_dict[level]

x_tr, y_tr = location_bbox
tile_xs, tile_ys = tile_coords.T
tile_xs = tile_xs - x_tr
tile_ys = tile_ys - y_tr
tile_xs = tile_xs//factor
tile_ys = tile_ys//factor

sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()])

return sel_coords
68 changes: 68 additions & 0 deletions InnerEye/ML/Histopathology/utils/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
import torch
import matplotlib.pyplot as plt
from math import ceil
import numpy as np
import matplotlib.patches as patches
import matplotlib.collections as collection

from InnerEye.ML.Histopathology.models.transforms import load_pil_image
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles


def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1,
Expand Down Expand Up @@ -97,3 +101,67 @@ def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str
for i in range(len(axs.ravel())):
axs.ravel()[i].set_axis_off()
return fig


def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure:
"""Plots a slide thumbnail from a given slide image and scale.
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
:return: matplotlib figure of the slide thumbnail.
"""
fig, ax = plt.subplots()
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
ax.set_axis_off()
original_size = fig.get_size_inches()
fig.set_size_inches((original_size[0]*scale, original_size[1]*scale))
return fig


def plot_heatmap_overlay(slide: str,
slide_image: np.ndarray,
results: Dict[str, List[Any]],
location_bbox: List[int],
tile_size: int = 224,
level: int = 1) -> plt.figure:
"""Plots heatmap of selected tiles (e.g. tiles in a bag) overlay on the corresponding slide.
:param slide: slide identifier.
:param slide_image: Numpy array of the slide image (shape: [3, H, W]).
:param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides.
vale-salvatelli marked this conversation as resolved.
Show resolved Hide resolved
:param tile_size: Size of each tile. Default 224.
:param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1.
:param location_bbox: Location of the bounding box of the slide.
:return: matplotlib figure of the heatmap of the given tiles on slide.
"""
fig, ax = plt.subplots()
slide_image = slide_image.transpose(1, 2, 0)
ax.imshow(slide_image)
ax.set_xlim(0, slide_image.shape[1])
ax.set_ylim(slide_image.shape[0], 0)

coords = []
slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]]
slide_idx = slide_ids.index(slide)
attentions = results[ResultsKey.BAG_ATTN][slide_idx]

# for each tile in the bag
for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])):
tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(),
results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()]))
coords.append(tile_coords)

coords = np.array(coords)
attentions = np.array(attentions.cpu()).reshape(-1)

sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level)
cmap = plt.cm.get_cmap('jet')
rects = []
for i in range(sel_coords.shape[0]):
rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size)
rects.append(rect)

pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None)
pc.set_array(np.array(attentions))
pc.set_clim([0, 1])
ax.add_collection(pc)
plt.colorbar(pc, ax=ax)
return fig
1 change: 1 addition & 0 deletions InnerEye/ML/Histopathology/utils/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SlideKey(str, Enum):
ORIGIN = 'origin'
FOREGROUND_THRESHOLD = 'foreground_threshold'
METADATA = 'metadata'
LOCATION = 'location'


class TileKey(str, Enum):
Expand Down
4 changes: 4 additions & 0 deletions InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import param
from torch import nn
from torchvision.models.resnet import resnet18
from monai.data.dataset import Dataset

from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from InnerEye.ML.lightning_container import LightningContainer
Expand Down Expand Up @@ -100,3 +101,6 @@ def create_model(self) -> DeepMILModule:

def get_data_module(self) -> TilesDataModule:
raise NotImplementedError

def get_slide_dataset(self) -> Dataset:
raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405
"""
from pathlib import Path
from typing import Any, Dict
from typing import Any, List
import os

from monai.transforms import Compose
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Callback

from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from health_azure.utils import get_workspace
Expand Down Expand Up @@ -125,9 +126,8 @@ def get_data_module(self) -> TilesDataModule:
cross_validation_split_index=self.cross_validation_split_index,
)

def get_trainer_arguments(self) -> Dict[str, Any]:
# These arguments will be passed through to the Lightning trainer.
return {"callbacks": self.callbacks}
def get_callbacks(self) -> List[Callback]:
return super().get_callbacks() + [self.callbacks]

def get_path_to_best_checkpoint(self) -> Path:
"""
Expand Down
44 changes: 36 additions & 8 deletions InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Any, Dict
from typing import Any, List
from pathlib import Path
import os
from monai.transforms import Compose
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
from monai.data.dataset import Dataset

from health_azure.utils import CheckpointDownloader
from health_azure.utils import get_workspace
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
Expand All @@ -26,8 +28,11 @@
ImageNetEncoder,
ImageNetSimCLREncoder,
InnerEyeSSLEncoder,
IdentityEncoder
)
from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL
from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule


class DeepSMILEPanda(BaseMIL):
Expand All @@ -38,6 +43,8 @@ def __init__(self, **kwargs: Any) -> None:
# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
azure_dataset_id="PANDA_tiles",
extra_azure_dataset_ids=["PANDA"],
extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")],
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=200,
Expand All @@ -48,11 +55,12 @@ def __init__(self, **kwargs: Any) -> None:
# declared in OptimizerParams:
l_rate=5e-4,
weight_decay=1e-4,
adam_betas=(0.9, 0.99),
)
adam_betas=(0.9, 0.99))
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)
super().__init__(**default_kwargs)
if not is_running_in_azure_ml():
self.num_epochs = 1
self.best_checkpoint_filename = "checkpoint_max_val_auroc"
self.best_checkpoint_filename_with_suffix = (
self.best_checkpoint_filename + ".ckpt"
Expand Down Expand Up @@ -109,9 +117,29 @@ def get_data_module(self) -> PandaTilesDataModule:
cross_validation_split_index=self.cross_validation_split_index,
)

def get_trainer_arguments(self) -> Dict[str, Any]:
# These arguments will be passed through to the Lightning trainer.
return {"callbacks": self.callbacks}
def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
# no-op IdentityEncoder to be used inside the model
self.slide_dataset = self.get_slide_dataset()
self.level = 1
return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
class_weights=self.data_module.class_weights,
l_rate=self.l_rate,
weight_decay=self.weight_decay,
adam_betas=self.adam_betas,
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level)

def get_slide_dataset(self) -> Dataset:
return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) # type: ignore
harshita-s marked this conversation as resolved.
Show resolved Hide resolved

def get_callbacks(self) -> List[Callback]:
return super().get_callbacks() + [self.callbacks]

def get_path_to_best_checkpoint(self) -> Path:
"""
Expand All @@ -135,7 +163,7 @@ def get_path_to_best_checkpoint(self) -> Path:
if checkpoint_path.is_file():
return checkpoint_path

raise ValueError("Path to best checkpoint not found")
raise ValueError("Path to best checkpoint not found")


class PandaImageNetMIL(DeepSMILEPanda):
Expand Down
Loading