diff --git a/pyproject.toml b/pyproject.toml index a9a8f65..8a870fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "monai", "nibabel", "pytorch-lightning", + "scikit-image", "scikit-learn", "SimpleITK", "sitk-cli", diff --git a/scripts/make_datalist.py b/scripts/make_datalist.py index 8ce70eb..5e7477b 100644 --- a/scripts/make_datalist.py +++ b/scripts/make_datalist.py @@ -5,9 +5,13 @@ import typer from segmantic.image.labels import load_tissue_list +from segmantic.seg.dataset import PairedDataSet from segmantic.utils.file_iterators import find_matching_files +app = typer.Typer() + +@app.command() def make_datalist( data_dir: Path = typer.Option( ..., @@ -73,8 +77,60 @@ def make_datalist( return datalist_path.write_text(json.dumps(data_config, indent=2)) +@app.command() +def extend_datalist( + data_dir: Path = typer.Option( + ..., + help="root data directory. Paths in datalist will be relative to this directory", + ), + image_dir: Path = typer.Option(..., help="Directory containing images"), + labels_dir: Path = typer.Option(None, help="Directory containing labels"), + datalist_path: Path = typer.Option(..., help="Filename of input datalist"), + output_path: Path = typer.Option(..., help="Filename of output datalist"), + image_glob: str = "*.nii.gz", + labels_glob: str = "*.nii.gz", +): + ds = PairedDataSet.load_from_json(datalist_path) + images = [d["image"].name.lower() for d in ds.training_files()] + images += [d["image"].name.lower() for d in ds.validation_files()] + images += [d["image"].name.lower() for d in ds.test_files()] + + if image_dir.is_absolute(): + image_dir = image_dir.relative_to(data_dir) + if labels_dir.is_absolute(): + labels_dir = labels_dir.relative_to(data_dir) + + matches = find_matching_files( + [data_dir / image_dir / image_glob, data_dir / labels_dir / labels_glob] + ) + + training_data = list(ds.training_files()) + for p in matches: + image_name = p[0].name.lower() + if image_name not in images: + training_data.append({"image": p[0], "label": p[1]}) + + def make_relative(d: dict[str, Path]): + return {key: str(d[key].relative_to(data_dir)) for key in d} + + data_config = json.loads(datalist_path.read_text()) + + data_config["training"] = [make_relative(v) for v in training_data] + data_config["validation"] = [make_relative(v) for v in ds.validation_files()] + # data_config["test"] = (make_relative(v)["image"] for v in ds.test_files()) + return output_path.write_text(json.dumps(data_config, indent=2)) + + +@app.command() +def print_stats(datalist: Path): + ds = PairedDataSet.load_from_json(datalist) + print(f"Training cases: {len(ds.training_files())}") + print(f"Validation cases: {len(ds.validation_files())}") + print(f"Test cases: {len(ds.test_files())}") + + def main(): - typer.run(make_datalist) + app() if __name__ == "__main__": diff --git a/scripts/preprocess.py b/scripts/preprocess.py new file mode 100644 index 0000000..ec7f3e2 --- /dev/null +++ b/scripts/preprocess.py @@ -0,0 +1,47 @@ +from pathlib import Path + +import typer +from monai.transforms import ( + Compose, + CropForegroundd, + ForegroundMaskd, + LoadImaged, + NormalizeIntensityd, + SaveImaged, +) + +from segmantic.seg.transforms import SelectChanneld + + +def pre_process(input_file: Path, output_dir: Path, margin: int = 0, channel: int = 0): + + transforms = Compose( + [ + LoadImaged( + keys="img", + reader="ITKReader", + image_only=False, + ensure_channel_first=True, + ), + SelectChanneld(keys="img", channel=channel, new_key_postfix="_0"), + 
ForegroundMaskd(keys="img_0", invert=True, new_key_prefix="mask"), + CropForegroundd( + keys="img", source_key="maskimg_0", allow_smaller=False, margin=margin + ), + NormalizeIntensityd(keys="img"), + SaveImaged( + keys="img", + writer="ITKWriter", + output_dir=output_dir, + output_postfix="", + resample=False, + separate_folder=False, + print_log=True, + ), + ] + ) + transforms({"img": input_file}) + + +if __name__ == "__main__": + typer.run(pre_process) diff --git a/src/segmantic/__init__.py b/src/segmantic/__init__.py index 108cc2d..a3e6f9a 100644 --- a/src/segmantic/__init__.py +++ b/src/segmantic/__init__.py @@ -1,4 +1,4 @@ """ ML-based segmentation for medical images """ -__version__ = "0.4.0" +__version__ = "0.5.0" diff --git a/src/segmantic/commands/monai_unet_cli.py b/src/segmantic/commands/monai_unet_cli.py index 540aefe..ca64fde 100644 --- a/src/segmantic/commands/monai_unet_cli.py +++ b/src/segmantic/commands/monai_unet_cli.py @@ -174,7 +174,7 @@ def predict( None, "--tissue-list", "-t", help="label descriptors in iSEG format" ), results_dir: Path = typer.Option( - None, "--results-dir", "-r", help="output directory" + ..., "--results-dir", "-r", help="output directory" ), spacing: list[float] = typer.Option( [], "--spacing", help="if specified, the image is first resampled" diff --git a/src/segmantic/seg/ensemble.py b/src/segmantic/seg/ensemble.py new file mode 100644 index 0000000..41cb7f8 --- /dev/null +++ b/src/segmantic/seg/ensemble.py @@ -0,0 +1,87 @@ +from collections.abc import Sequence +from typing import Optional, Union + +import torch +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.networks.utils import one_hot +from monai.transforms.post.array import Ensemble +from monai.transforms.post.dictionary import Ensembled +from monai.transforms.transform import Transform +from monai.utils import TransformBackends + + +class SelectBestEnsemble(Ensemble, Transform): + """ + Execute select best ensemble on the input data. + The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], + Or a single PyTorch Tensor with shape: [E[, C, H, W, D]], the `E` dimension represents + the output data from different models. + Typically, the input data is model output of segmentation task or classification task. + + Note: + This select best transform expects the input data is discrete single channel values. + It selects the tissue of the model which performed best in a generalization analysis. + The mapping is saved in the label_model_dict. + The output data has the same shape as every item of the input data. + + Args: + label_model_dict: dictionary containing the best models index for each tissue and + the tissue labels. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, label_model_dict: dict[int, int]) -> None: + self.label_model_dict = label_model_dict + + def __call__( + self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor] + ) -> NdarrayOrTensor: + img_ = self.get_stacked_torch(img) + + has_ch_dim = False + if img_.ndimension() > 1 and img_.shape[1] > 1: + # convert multi-channel (One-Hot) images to argmax + img_ = torch.argmax(img_, dim=1, keepdim=True) + has_ch_dim = True + + # combining the tissues from the best performing models + out_pt = torch.empty(img_.size()[1:]) + for tissue_id, model_id in self.label_model_dict.items(): + temp_best_tissue = img_[model_id, ...] 
+ out_pt[temp_best_tissue == tissue_id] = tissue_id + + if has_ch_dim: + # convert back to multi-channel (One-Hot) + num_classes = max(self.label_model_dict.keys()) + 1 + out_pt = one_hot(out_pt, num_classes, dim=0) + + return self.post_convert(out_pt, img) + + +class SelectBestEnsembled(Ensembled): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SelectBestEnsemble`. + """ + + backend = SelectBestEnsemble.backend + + def __init__( + self, + label_model_dict: dict[int, int], + keys: KeysCollection, + output_key: Optional[str] = None, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be stack and execute ensemble. + if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`. + output_key: the key to store ensemble result in the dictionary. + if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default. + + """ + ensemble = SelectBestEnsemble( + label_model_dict=label_model_dict, + ) + super().__init__(keys, ensemble, output_key) diff --git a/src/segmantic/seg/losses.py b/src/segmantic/seg/losses.py new file mode 100644 index 0000000..5c7c21c --- /dev/null +++ b/src/segmantic/seg/losses.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import torch +from monai.networks import one_hot +from monai.utils import LossReduction +from torch.nn.modules.loss import _Loss + + +class AsymmetricUnifiedFocalLoss(_Loss): + """ + AsymmetricUnifiedFocalLoss is a variant of Focal Loss. + + Implementation of the Asymmetric Unified Focal Loss described in: + + - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", + Michael Yeung, Computerized Medical Imaging and Graphics + - https://github.com/mlyg/unified-focal-loss/issues/8 + """ + + def __init__( + self, + to_onehot_y: bool = False, + num_classes: int = 2, + gamma: float = 0.75, + delta: float = 0.7, + reduction: LossReduction | str = LossReduction.MEAN, + ): + """ + Args: + to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. + num_classes : number of classes. Defaults to 2. + delta : weight of the background. Defaults to 0.7. + gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. + + Example: + >>> import torch + >>> from monai.losses import AsymmetricUnifiedFocalLoss + >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) + >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) + >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) + >>> fl(pred, grnd) + """ + super().__init__(reduction=LossReduction(reduction).value) + self.to_onehot_y = to_onehot_y + self.num_classes = num_classes + self.gamma = gamma + self.delta = delta + self.epsilon = torch.tensor(1e-7) + + def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: + """ + Args: + y_pred : the shape should be BNH[WD], where N is the number of classes. + It only supports binary segmentation. + The input should be the original logits since it will be transformed by + a sigmoid in the forward function. + y_true : the shape should be BNH[WD], where N is the number of classes. + It only supports binary segmentation. 
+ + Raises: + ValueError: When input and target are different shape + ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 + ValueError: When num_classes + ValueError: When the number of classes entered does not match the expected number + """ + if y_pred.shape != y_true.shape: + raise ValueError( + f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})" + ) + + if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: + raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + + if y_pred.shape[1] == 1: + y_pred = one_hot(y_pred, num_classes=self.num_classes) + y_true = one_hot(y_true, num_classes=self.num_classes) + + if torch.max(y_true) != self.num_classes - 1: + raise ValueError( + f"Please make sure the number of classes is {self.num_classes-1}" + ) + + # use pytorch CrossEntropyLoss ? + # https://github.com/Project-MONAI/MONAI/blob/2d463a7d19166cff6a83a313f339228bc812912d/monai/losses/dice.py#L741 + epsilon = torch.tensor(self.epsilon, device=y_pred.device) + y_pred = torch.clip(y_pred, epsilon, 1.0 - epsilon) + cross_entropy = -y_true * torch.log(y_pred) + + # calculate losses separately for each class, only enhancing foreground class + back_ce = ( + torch.pow(1 - y_pred[:, 0, ...], self.gamma) * cross_entropy[:, 0, ...] + ) + back_ce = (1 - self.delta) * back_ce + + losses = [back_ce] + for i in range(1, self.num_classes): + i_ce = cross_entropy[:, i, ...] + i_ce = self.delta * i_ce + losses.append(i_ce) + + loss = torch.stack(losses) + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss) # sum over the batch and channel dims + if self.reduction == LossReduction.NONE.value: + return loss # returns [N, num_classes] losses + if self.reduction == LossReduction.MEAN.value: + return torch.mean(loss) + raise ValueError( + f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].' 
+ ) diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 3a1049d..4e8ad04 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -2,6 +2,7 @@ import os import subprocess as sp import sys +import warnings from collections.abc import Sequence from functools import partial from pathlib import Path @@ -23,7 +24,7 @@ from monai.engines import EnsembleEvaluator from monai.inferers import SlidingWindowInferer, sliding_window_inference from monai.losses import DiceLoss -from monai.metrics import ConfusionMatrixMetric, CumulativeAverage, DiceMetric +from monai.metrics import CumulativeAverage, DiceMetric from monai.networks.layers.factories import Norm from monai.networks.nets import UNet from monai.networks.utils import one_hot @@ -32,8 +33,10 @@ AsDiscreted, Compose, CropForegroundd, + DeleteItemsd, EnsureType, EnsureTyped, + ForegroundMaskd, Invertd, LoadImaged, MeanEnsembled, @@ -63,8 +66,9 @@ from pytorch_lightning.loggers import TensorBoardLogger from ..image.labels import load_decathlon_tissuelist, load_tissue_list +from ..seg.ensemble import SelectBestEnsembled from ..seg.enum import EnsembleCombination -from ..seg.transforms import SelectBestEnsembled +from ..seg.transforms import SelectChanneld from ..utils import config from .dataset import PairedDataSet from .evaluation import confusion_matrix @@ -83,10 +87,7 @@ class Net(pl.LightningModule): optimizer: dict = { "optimizer": "Adam", "lr": 1e-4, - "momentum": 0.9, - "epsilon": 1e-8, "amsgrad": False, - "weight_decouple": False, } lr_scheduling: dict = { "scheduler": "Constant", @@ -106,6 +107,7 @@ def __init__( strides: tuple[int, ...] = (2, 2, 2, 2), dropout: float = 0.0, act: str = "PRELU", + threshold_foreground: bool = False, ): super().__init__() @@ -136,9 +138,15 @@ def __init__( self.dice_metric = DiceMetric( include_background=False, reduction="mean", get_not_nans=False ) + self.threshold_foreground = threshold_foreground self.best_val_dice = 0.0 self.best_val_epoch = 0.0 self.validation_step_outputs: list[dict] = [] + self.training_step_outputs: list[torch.Tensor] = [] + + @property + def num_channels(self): + return self._model.in_channels @property def num_classes(self): @@ -161,12 +169,33 @@ def default_preprocessing( ensure_channel_first=True, ), Orientationd(keys=keys, axcodes="RAS"), - NormalizeIntensityd(keys="image", nonzero=False, channel_wise=True), + ] + + threshold_foreground = self.threshold_foreground or "label" not in keys + threshold_key = "image" + if threshold_foreground: + if self.num_channels > 0: + threshold_key = "maskimage_0" + xforms += [ + SelectChanneld(keys="image", new_key_postfix="_0"), + ForegroundMaskd(keys="image_0", invert=True, new_key_prefix="mask"), + DeleteItemsd(keys="image_0"), + ] + else: + threshold_key = "maskimage" + xforms.append( + ForegroundMaskd(keys="image", invert=True, new_key_prefix="mask") + ) + + xforms += [ CropForegroundd( keys=keys, - source_key="label" if "label" in keys else "image", + source_key=threshold_key if threshold_foreground else "label", + margin=0, allow_smaller=False, ), + DeleteItemsd(keys=threshold_key), + NormalizeIntensityd(keys="image", nonzero=False, channel_wise=True), EnsureTyped(keys=keys, dtype=torch.float32, device=self.device), # type: ignore ] @@ -226,7 +255,7 @@ def prepare_data(self) -> None: raise RuntimeError("The dataset is not set") # set deterministic training for reproducibility - set_determinism(seed=0) + set_determinism(seed=42) # define the data transforms preprocessing = 
None @@ -300,7 +329,7 @@ def configure_optimizers(self): optimizer = torch.optim.Adam( self._model.parameters(), lr=self.optimizer["lr"], - amsgrad=self.optimizer["amsgrad"], + amsgrad=self.optimizer.get("amsgrad"), ) elif self.optimizer["optimizer"] == "AdaBelief": optimizer = AdaBelief( @@ -344,9 +373,16 @@ def training_step(self, batch, batch_idx): loss = self.loss_function(output, labels) self.manual_backward(loss) optimizer.step() + self.training_step_outputs.append(loss) tensorboard_logs = {"train_loss": loss.item()} return {"loss": loss, "log": tensorboard_logs} + def on_train_epoch_end(self): + # do something with all training_step outputs, for example: + self.train_loss = torch.stack(self.training_step_outputs).mean() + self.log("train_loss", self.train_loss) + self.training_step_outputs.clear() + def validation_step(self, batch, batch_idx): images, labels = batch["image"], batch["label"] roi_size = tuple(160 for _ in range(self.spatial_dims)) @@ -360,7 +396,7 @@ def validation_step(self, batch, batch_idx): self.dice_metric(y_pred=outputs, y=labels) d = {"val_loss": loss, "val_number": len(outputs)} self.validation_step_outputs.append(d) - return {"val_loss": loss, "val_number": len(outputs)} + return d def on_validation_epoch_end(self): val_loss, num_items = 0, 0 @@ -412,6 +448,7 @@ def train( augmentation: dict = {}, augment_intensity: bool = False, augment_spatial: bool = False, + threshold_foreground: bool = False, channels: tuple[int, ...] = (16, 32, 64, 128, 256), strides: tuple[int, ...] = (2, 2, 2, 2), dropout: float = 0.0, @@ -430,10 +467,7 @@ def train( optimizer = { "optimizer": "Adam", "lr": 1e-4, - "momentum": 0.9, - "epsilon": 1e-8, "amsgrad": False, - "weight_decouple": False, } if lr_scheduling is None: lr_scheduling = { @@ -446,7 +480,11 @@ def train( # initialise the LightningModule if checkpoint_file and Path(checkpoint_file).exists(): - net: Net = Net.load_from_checkpoint(f"{checkpoint_file}", map_location="cpu") + net: Net = Net.load_from_checkpoint( + f"{checkpoint_file}", + map_location="cpu", + threshold_foreground=threshold_foreground, + ) net.best_val_dice = 0.0 else: if num_classes > 0 and tissue_list: @@ -474,6 +512,7 @@ def train( strides=strides, dropout=dropout, act=act, + threshold_foreground=threshold_foreground, ) if image_dir and labels_dir: net.dataset = PairedDataSet(image_dir=image_dir, labels_dir=labels_dir) @@ -505,7 +544,8 @@ def train( monitor="val_dice", mode="max", dirpath=output_dir if output_dir else log_dir, - save_top_k=3, + save_top_k=5, + save_last=True, ) # defining early stopping. When val loss improves less than 0 over 30 epochs, the training will be stopped. @@ -551,8 +591,8 @@ def train( def predict( model_file: Path, test_images: list[Path], + output_dir: Path, test_labels: Optional[list[Path]] = None, - output_dir: Path = None, tissue_dict: dict[str, int] = None, channels: tuple[int, ...] = (16, 32, 64, 128, 256), strides: tuple[int, ...] 
= (2, 2, 2, 2), @@ -642,8 +682,6 @@ def predict( dice_metric = DiceMetric( include_background=False, reduction="mean", get_not_nans=False ) - confusion_metrics = ["sensitivity", "specificity", "precision", "accuracy"] - conf_matrix = ConfusionMatrixMetric(metric_name=confusion_metrics) mean_class_dice = CumulativeAverage() def to_one_hot(x): @@ -653,75 +691,70 @@ def to_one_hot(x): if tissue_dict: for name in tissue_dict.keys(): idx = tissue_dict[name] - tissue_names[idx] = name - - def print_table(header, vals, indent="\t"): - print(indent + "\t".join(header).expandtabs(30)) - print(indent + "\t".join(f"{x}" for x in vals).expandtabs(30)) + tissue_names[idx] = name.strip() + confusion = None all_mean_dice = [] with torch.no_grad(): - for test_data in test_loader: - val_pred = inferer(test_data["image"].to(device), net) - assert isinstance(val_pred, torch.Tensor) - - test_data["pred"] = val_pred - for i in decollate_batch(test_data): - post_transforms(i) + with open(output_dir / "_eval_dice.csv", "w") as eval_fp: + print(",".join(["Case"] + tissue_names), file=eval_fp) + + for idx, test_data in enumerate(test_loader): + val_pred = inferer(test_data["image"].to(device), net) + assert isinstance(val_pred, torch.Tensor) + + test_data["pred"] = val_pred + for i in decollate_batch(test_data): + post_transforms(i) + + if test_labels: + filename = test_data["image_meta_dict"]["filename_or_obj"] + if filename and isinstance(filename, list): + filename = filename[0] + name = ( + Path(filename).stem.replace(".nii", "") + if filename + else f"{idx}" + ) - if test_labels: - val_pred = val_pred.argmax(dim=1, keepdim=True) - val_labels = test_data["label"].to(device).long() + val_pred = val_pred.argmax(dim=1, keepdim=True) + val_labels = test_data["label"].to(device).long() - dice: torch.Tensor = dice_metric( # type: ignore [assignment] - y_pred=to_one_hot(val_pred), y=to_one_hot(val_labels) - ) - mean_class_dice.append(dice) - conf_matrix(y_pred=to_one_hot(val_pred), y=to_one_hot(val_labels)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dice: torch.Tensor = dice_metric( # type: ignore [assignment] + y_pred=to_one_hot(val_pred), y=to_one_hot(val_labels) + ) + mean_class_dice.append(dice) - dice_np = dice.cpu().numpy() - print("Mean Dice: ", np.mean(dice_np)) - print("Class Dice:") - print_table(tissue_names, np.squeeze(dice_np)) + dice_np = dice.cpu().numpy() + print( + f"[{idx+1}/{len(test_loader)}] {name}, mean Dice: ", + np.mean(dice_np), + ) - all_mean_dice.append(dice_metric.aggregate().item()) # type: ignore + class_dice = [str(v) for v in np.squeeze(dice_np).tolist()] + print(", ".join([name] + class_dice), file=eval_fp) - filename_or_obj = test_data["image_meta_dict"]["filename_or_obj"] - if filename_or_obj and isinstance(filename_or_obj, list): - filename_or_obj = filename_or_obj[0] + all_mean_dice.append(dice_metric.aggregate().item()) # type: ignore - if output_dir and filename_or_obj: - base = Path(filename_or_obj).stem.replace(".nii", "") c = confusion_matrix( num_classes=num_classes, y_pred=val_pred.view(-1).cpu().numpy(), y=val_labels.view(-1).cpu().numpy(), ) - plot_confusion_matrix( - c, - tissue_names, - file_name=output_dir / (base + "_confusion.png"), - ) - if output_dir is None: - print("No output path specified, dice scores won't be saved.") - else: - np.savetxt( - output_dir / f"mean_dice_{model_file.stem}_generalized_score.txt", - all_mean_dice, - delimiter=",", - ) + confusion = c if confusion is None else (confusion + c) + np.savetxt( + output_dir / 
f"_eval_mean_dice_{model_file.stem}.txt", + all_mean_dice, + delimiter=",", + ) - if test_labels: - print("*" * 80) - print("Total Mean Dice: ", dice_metric.aggregate().item()) # type: ignore - print("Total Class Dice:") - print_table( - tissue_names, np.squeeze(mean_class_dice.aggregate().cpu().numpy()) # type: ignore - ) - print("Total Conf. Matrix Metrics:") - print_table( - confusion_metrics, - (np.squeeze(x.cpu().numpy()) for x in conf_matrix.aggregate()), # type: ignore + if test_labels and confusion is not None: + plot_confusion_matrix( + confusion, + tissue_names, + file_name=output_dir / ("_eval_confusion.png"), ) diff --git a/src/segmantic/seg/transforms.py b/src/segmantic/seg/transforms.py index 103dcb3..014c622 100644 --- a/src/segmantic/seg/transforms.py +++ b/src/segmantic/seg/transforms.py @@ -1,93 +1,14 @@ -from collections.abc import Hashable, Mapping, Sequence -from typing import Optional, Union +from collections.abc import Hashable, Mapping import torch from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta -from monai.networks.utils import one_hot -from monai.transforms.post.array import Ensemble -from monai.transforms.post.dictionary import Ensembled +from monai.data.meta_tensor import MetaTensor from monai.transforms.transform import MapTransform, Transform from monai.utils import TransformBackends, convert_to_dst_type, convert_to_tensor -class SelectBestEnsemble(Ensemble, Transform): - """ - Execute select best ensemble on the input data. - The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], - Or a single PyTorch Tensor with shape: [E[, C, H, W, D]], the `E` dimension represents - the output data from different models. - Typically, the input data is model output of segmentation task or classification task. - - Note: - This select best transform expects the input data is discrete single channel values. - It selects the tissue of the model which performed best in a generalization analysis. - The mapping is saved in the label_model_dict. - The output data has the same shape as every item of the input data. - - Args: - label_model_dict: dictionary containing the best models index for each tissue and - the tissue labels. - """ - - backend = [TransformBackends.TORCH] - - def __init__(self, label_model_dict: dict[int, int]) -> None: - self.label_model_dict = label_model_dict - - def __call__( - self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor] - ) -> NdarrayOrTensor: - img_ = self.get_stacked_torch(img) - - has_ch_dim = False - if img_.ndimension() > 1 and img_.shape[1] > 1: - # convert multi-channel (One-Hot) images to argmax - img_ = torch.argmax(img_, dim=1, keepdim=True) - has_ch_dim = True - - # combining the tissues from the best performing models - out_pt = torch.empty(img_.size()[1:]) - for tissue_id, model_id in self.label_model_dict.items(): - temp_best_tissue = img_[model_id, ...] - out_pt[temp_best_tissue == tissue_id] = tissue_id - - if has_ch_dim: - # convert back to multi-channel (One-Hot) - num_classes = max(self.label_model_dict.keys()) + 1 - out_pt = one_hot(out_pt, num_classes, dim=0) - - return self.post_convert(out_pt, img) - - -class SelectBestEnsembled(Ensembled): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.SelectBestEnsemble`. 
- """ - - backend = SelectBestEnsemble.backend - - def __init__( - self, - label_model_dict: dict[int, int], - keys: KeysCollection, - output_key: Optional[str] = None, - ) -> None: - """ - Args: - keys: keys of the corresponding items to be stack and execute ensemble. - if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`. - output_key: the key to store ensemble result in the dictionary. - if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default. - - """ - ensemble = SelectBestEnsemble( - label_model_dict=label_model_dict, - ) - super().__init__(keys, ensemble, output_key) - - class MapLabels(Transform): """ """ @@ -125,3 +46,54 @@ def __call__( for key in self.key_iterator(d): d[key] = self.converter(d[key]) return d + + +class SelectChannel(Transform): + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, channel_dim: int = 0, channel: int = 0) -> None: + self.channel_dim = channel_dim + self.channel = channel + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img` + """ + if isinstance(img, torch.Tensor): + output = img.index_select( + index=torch.tensor(self.channel, dtype=torch.long, device=img.device), + dim=self.channel_dim, + ) + else: + output = img.take(indices=self.channel, axis=self.channel_dim) + + if isinstance(img, MetaTensor) and not isinstance(output, MetaTensor): + output = MetaTensor(output, meta=img.meta) + + return output + + +class SelectChanneld(MapTransform): + + backend = SelectChannel.backend + + def __init__( + self, + keys: KeysCollection, + channel_dim: int = 0, + channel: int = 0, + new_key_postfix: str = "_sel", + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) + self.postfix = new_key_postfix + self.op = SelectChannel(channel_dim, channel) + + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[str(key) + self.postfix] = self.op(d[key]) + return d diff --git a/src/segmantic/seg/visualization.py b/src/segmantic/seg/visualization.py index 4368f5d..e1667c3 100644 --- a/src/segmantic/seg/visualization.py +++ b/src/segmantic/seg/visualization.py @@ -80,7 +80,7 @@ def plot_confusion_matrix( target_names = y_labels_vals, # list of names of the classes title = best_estimator_name) # title of graph - Citiation + Citation --------- http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html @@ -93,7 +93,8 @@ def plot_confusion_matrix( cmap = plt.get_cmap("Blues") if normalize: - cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] + with np.errstate(divide="ignore", invalid="ignore"): + cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] fig = plt.figure(figsize=(16, 16)) plt.imshow(cm, interpolation="nearest", cmap=cmap) @@ -111,7 +112,7 @@ def plot_confusion_matrix( plt.text( j, i, - f"{cm[i, j]:0.4f}", + f"{cm[i, j]:0.2f}", horizontalalignment="center", color="white" if cm[i, j] > thresh else "black", ) diff --git a/tests/seg/test_MapLabels.py b/tests/seg/test_MapLabels.py deleted file mode 100644 index fcf92b8..0000000 --- a/tests/seg/test_MapLabels.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from monai.bundle import ConfigParser - -from segmantic.seg.transforms import MapLabels - - -def test_MapLabels(): - labels = torch.tensor([2, 1, 2, 0]).reshape(1, 4, 1, 1) - - mapper = MapLabels({1: 3, 2: 1, 0: 0}) - labels_mapped = mapper(labels) - 
assert (labels_mapped == torch.tensor([1, 3, 1, 0]).reshape(1, 4, 1, 1)).all() - - -def test_Bundle_MapLabels(): - parser = ConfigParser( - { - "imports": ["$import segmantic"], - "mapping": "${1: 3, 2: 1, 0: 0}", - "postpro": "$segmantic.seg.transforms.MapLabels(@mapping)", - } - ) - parser.parse(True) - mapper = parser.get_parsed_content("postpro") - print(mapper) - assert isinstance(mapper, MapLabels) - - labels = torch.tensor([2, 1, 2, 0]).reshape(1, 4, 1, 1) - labels_mapped = mapper(labels) - assert (labels_mapped == torch.tensor([1, 3, 1, 0]).reshape(1, 4, 1, 1)).all() - - -if __name__ == "__main__": - test_Bundle_MapLabels() diff --git a/tests/seg/test_ensemble.py b/tests/seg/test_ensemble.py new file mode 100644 index 0000000..944fae2 --- /dev/null +++ b/tests/seg/test_ensemble.py @@ -0,0 +1,43 @@ +import torch +from monai.networks import one_hot +from torch.testing import assert_close + +from segmantic.seg.ensemble import SelectBestEnsembled + + +def test_SelectBestEnsembled(): + # label (e.g. "argmax") data + input = { + "pred0": torch.ones(1, 3, 1, 1), + "pred1": torch.tensor([2, 0, 2]).reshape(1, 3, 1, 1), + "pred2": torch.tensor([2, 1, 0]).reshape(1, 3, 1, 1), + } + expected_value = torch.tensor([2, 1, 0], dtype=torch.float32).reshape(1, 3, 1, 1) + if torch.cuda.is_available(): + for k in input.keys(): + input[k] = input[k].to(torch.device("cuda:0")) + expected_value = expected_value.to(torch.device("cuda:0")) + + tr = SelectBestEnsembled( + keys=["pred0", "pred1", "pred2"], + output_key="output", + label_model_dict={1: 0, 2: 1, 0: 2}, + ) + result = tr(input) + assert_close(result["output"], expected_value) + + # One-Hot data + input = {k: one_hot(i, num_classes=3, dim=0) for k, i in input.items()} + expected_value = one_hot(expected_value, num_classes=3, dim=0) + + tr = SelectBestEnsembled( + keys=["pred0", "pred1", "pred2"], + output_key="output", + label_model_dict={1: 0, 2: 1, 0: 2}, + ) + result = tr(input) + assert_close(result["output"], expected_value) + + +if __name__ == "__main__": + test_SelectBestEnsembled() diff --git a/tests/seg/test_losses.py b/tests/seg/test_losses.py new file mode 100644 index 0000000..ceb558b --- /dev/null +++ b/tests/seg/test_losses.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from segmantic.seg.losses import AsymmetricUnifiedFocalLoss + +TEST_CASES = [ + ( # shape: (2, 1, 2, 2), (2, 1, 2, 2) + { + "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + }, + 0.0, + ), + ( # shape: (2, 1, 2, 2), (2, 1, 2, 2) + { + "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + }, + 0.0, + ), +] + + +@pytest.mark.parametrize("input_data,expected_val", TEST_CASES) +def test_result(input_data, expected_val): + loss = AsymmetricUnifiedFocalLoss() + result = loss(**input_data) + np.testing.assert_allclose( + result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4 + ) + + +def test_ill_shape(): + loss = AsymmetricUnifiedFocalLoss() + with pytest.raises(ValueError): + loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2))) + + +def test_with_cuda(): + loss = AsymmetricUnifiedFocalLoss() + i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = 
loss(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + test_with_cuda() diff --git a/tests/seg/test_transforms.py b/tests/seg/test_transforms.py index 8c0b881..88ae1e3 100644 --- a/tests/seg/test_transforms.py +++ b/tests/seg/test_transforms.py @@ -1,43 +1,67 @@ +import numpy as np import torch -from monai.networks import one_hot -from torch.testing import assert_close - -from segmantic.seg.transforms import SelectBestEnsembled - - -def test_SelectBestEnsembled(): - # label (e.g. "argmax") data - input = { - "pred0": torch.ones(1, 3, 1, 1), - "pred1": torch.tensor([2, 0, 2]).reshape(1, 3, 1, 1), - "pred2": torch.tensor([2, 1, 0]).reshape(1, 3, 1, 1), - } - expected_value = torch.tensor([2, 1, 0], dtype=torch.float32).reshape(1, 3, 1, 1) - if torch.cuda.is_available(): - for k in input.keys(): - input[k] = input[k].to(torch.device("cuda:0")) - expected_value = expected_value.to(torch.device("cuda:0")) - - tr = SelectBestEnsembled( - keys=["pred0", "pred1", "pred2"], - output_key="output", - label_model_dict={1: 0, 2: 1, 0: 2}, +from monai.bundle import ConfigParser + +from segmantic.seg.transforms import MapLabels, SelectChannel, SelectChanneld + + +def test_MapLabels(): + labels = torch.tensor([2, 1, 2, 0]).reshape(1, 4, 1, 1) + + mapper = MapLabels({1: 3, 2: 1, 0: 0}) + labels_mapped = mapper(labels) + assert (labels_mapped == torch.tensor([1, 3, 1, 0]).reshape(1, 4, 1, 1)).all() + + +def test_Bundle_MapLabels(): + parser = ConfigParser( + { + "imports": ["$import segmantic"], + "mapping": "${1: 3, 2: 1, 0: 0}", + "postpro": "$segmantic.seg.transforms.MapLabels(@mapping)", + } ) - result = tr(input) - assert_close(result["output"], expected_value) + parser.parse(True) + mapper = parser.get_parsed_content("postpro") + print(mapper) + assert isinstance(mapper, MapLabels) + + labels = torch.tensor([2, 1, 2, 0]).reshape(1, 4, 1, 1) + labels_mapped = mapper(labels) + assert (labels_mapped == torch.tensor([1, 3, 1, 0]).reshape(1, 4, 1, 1)).all() + + +def test_SelectChannel(): + for asarray in (np.asarray, torch.tensor): + img = asarray([2, 1, 2, 0, 5, 6, 7, 3]).reshape(2, 4, 1, 1) + + select = SelectChannel(channel_dim=0, channel=0) + img_channel0 = select(img) + assert (img_channel0 == asarray([2, 1, 2, 0]).reshape(1, 4, 1, 1)).all() + + +def test_SelectChannel_single_channel_input(): + for asarray in (np.asarray, torch.tensor): + img = asarray([2, 1, 2, 0]).reshape(1, 4, 1, 1) + + select = SelectChannel(channel_dim=0, channel=0) + img_channel0 = select(img) + assert (img_channel0 == img).all() + - # One-Hot data - input = {k: one_hot(i, num_classes=3, dim=0) for k, i in input.items()} - expected_value = one_hot(expected_value, num_classes=3, dim=0) +def test_SelectChanneld(): + img = torch.tensor([2, 1, 2, 0, 5, 6, 7, 3]).reshape(2, 4, 1, 1) - tr = SelectBestEnsembled( - keys=["pred0", "pred1", "pred2"], - output_key="output", - label_model_dict={1: 0, 2: 1, 0: 2}, + select = SelectChanneld( + keys="img", channel_dim=0, channel=0, new_key_postfix="_sel" ) - result = tr(input) - assert_close(result["output"], expected_value) + img_channel0 = select({"img": img}) + assert isinstance(img_channel0, dict) + assert "img_sel" in img_channel0 + assert ( + img_channel0["img_sel"] == torch.tensor([2, 1, 2, 0]).reshape(1, 4, 1, 1) + ).all() if __name__ == "__main__": - test_SelectBestEnsembled() + test_SelectChannel_single_channel_input()
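
Usage note for the reworked datalist script above: make_datalist.py is now a Typer app with three sub-commands (make-datalist, extend-datalist, print-stats, following Typer's default underscore-to-hyphen naming). A minimal sketch, assuming the script is importable from scripts/ and using hypothetical paths:

# Hypothetical example of the new sub-commands in scripts/make_datalist.py.
# Equivalent CLI (names per Typer's default mapping):
#   python scripts/make_datalist.py extend-datalist --data-dir /data --image-dir images \
#       --labels-dir labels --datalist-path /data/datalist.json --output-path /data/datalist_ext.json
#   python scripts/make_datalist.py print-stats /data/datalist_ext.json
from pathlib import Path

from make_datalist import extend_datalist, print_stats  # assumes scripts/ is on the import path

extend_datalist(
    data_dir=Path("/data"),                       # hypothetical dataset root
    image_dir=Path("images"),                     # directory with newly added images under /data
    labels_dir=Path("labels"),
    datalist_path=Path("/data/datalist.json"),    # existing datalist to extend
    output_path=Path("/data/datalist_ext.json"),  # extended datalist; unseen cases are appended to "training"
)
print_stats(Path("/data/datalist_ext.json"))      # prints training/validation/test case counts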
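
The new scripts/preprocess.py is exposed via typer.run(pre_process), so input_file and output_dir become positional CLI arguments while margin and channel are options. A sketch with hypothetical file names:

# Hypothetical use of the new preprocessing script
# (CLI equivalent: python scripts/preprocess.py ct.nii.gz preprocessed --margin 2 --channel 0)
from pathlib import Path

from preprocess import pre_process  # scripts/preprocess.py, assuming it is importable

pre_process(
    input_file=Path("ct.nii.gz"),     # hypothetical input volume
    output_dir=Path("preprocessed"),  # cropped + normalized image is written here under its original name
    margin=2,                         # extra voxels kept around the foreground bounding box
    channel=0,                        # channel used to build the foreground mask
)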
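
Also worth noting: predict (and the --results-dir option of the monai-unet CLI) now requires an output directory, since per-case Dice scores (_eval_dice.csv), the mean-Dice summary, and the aggregated confusion matrix plot are written there instead of being printed as tables. A minimal call sketch with hypothetical paths, limited to the arguments visible in this diff:

# Hypothetical call of the updated predict(); output_dir is now required.
from pathlib import Path

from segmantic.seg.monai_unet import predict

predict(
    model_file=Path("runs/best_model.ckpt"),                        # hypothetical checkpoint
    test_images=sorted(Path("data/test/images").glob("*.nii.gz")),
    output_dir=Path("runs/eval"),                                   # _eval_dice.csv, _eval_mean_dice_*.txt, _eval_confusion.png
    test_labels=sorted(Path("data/test/labels").glob("*.nii.gz")),
    tissue_dict={"Background": 0, "Bone": 1, "Muscle": 2},          # hypothetical name -> label index mapping
)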