✨♻️ refactor preprocessing #64
Open: dyollb wants to merge 4 commits into `main` from `refactor_preprocessing`
New file (`@@ -0,0 +1,47 @@`):

```python
from pathlib import Path

import typer
from monai.transforms import (
    Compose,
    CropForegroundd,
    ForegroundMaskd,
    LoadImaged,
    NormalizeIntensityd,
    SaveImaged,
)

from segmantic.seg.transforms import SelectChanneld


def pre_process(input_file: Path, output_dir: Path, margin: int = 0, channel: int = 0):
    transforms = Compose(
        [
            LoadImaged(
                keys="img",
                reader="ITKReader",
                image_only=False,
                ensure_channel_first=True,
            ),
            SelectChanneld(keys="img", channel=channel, new_key_postfix="_0"),
            ForegroundMaskd(keys="img_0", invert=True, new_key_prefix="mask"),
            CropForegroundd(
                keys="img", source_key="maskimg_0", allow_smaller=False, margin=margin
            ),
            NormalizeIntensityd(keys="img"),
            SaveImaged(
                keys="img",
                writer="ITKWriter",
                output_dir=output_dir,
                output_postfix="",
                resample=False,
                separate_folder=False,
                print_log=True,
            ),
        ]
    )
    transforms({"img": input_file})


if __name__ == "__main__":
    typer.run(pre_process)
```
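The central step of this pipeline is cropping the image to the bounding box of the foreground mask, expanded by `margin`. As a rough illustration of that behaviour (a NumPy-only sketch, not MONAI's actual `CropForegroundd`, which additionally records the crop so it can be inverted):

```python
import numpy as np


def crop_foreground(img: np.ndarray, mask: np.ndarray, margin: int = 0) -> np.ndarray:
    """Crop `img` to the bounding box of `mask`, expanded by `margin` voxels."""
    coords = np.argwhere(mask)
    # bounding box, clipped to the image extent
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + 1 + margin, mask.shape)
    slices = tuple(slice(l, h) for l, h in zip(lo, hi))
    return img[slices]


img = np.zeros((8, 8))
img[2:5, 3:6] = 1.0
cropped = crop_foreground(img, img > 0, margin=1)
print(cropped.shape)  # (5, 5): the 3x3 foreground box grown by 1 on each side
```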
Version bump (`@@ -1,4 +1,4 @@`):

```diff
 """ ML-based segmentation for medical images
 """

-__version__ = "0.4.0"
+__version__ = "0.5.0"
```
New file (`@@ -0,0 +1,87 @@`):

```python
from collections.abc import Sequence
from typing import Optional, Union

import torch
from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.utils import one_hot
from monai.transforms.post.array import Ensemble
from monai.transforms.post.dictionary import Ensembled
from monai.transforms.transform import Transform
from monai.utils import TransformBackends


class SelectBestEnsemble(Ensemble, Transform):
    """
    Execute select-best ensembling on the input data.
    The input data can be a list or tuple of PyTorch Tensors with shape [C[, H, W, D]],
    or a single PyTorch Tensor with shape [E[, C, H, W, D]], where the `E` dimension
    indexes the outputs of the different models.
    Typically, the input data is the model output of a segmentation or classification task.

    Note:
        This transform expects discrete single-channel input values.
        For each tissue label it keeps the prediction of the model that performed
        best in a generalization analysis; the mapping is given by `label_model_dict`.
        The output data has the same shape as each item of the input data.

    Args:
        label_model_dict: dictionary mapping each tissue label to the index of the
            best-performing model for that tissue.
    """

    backend = [TransformBackends.TORCH]

    def __init__(self, label_model_dict: dict[int, int]) -> None:
        self.label_model_dict = label_model_dict

    def __call__(
        self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]
    ) -> NdarrayOrTensor:
        img_ = self.get_stacked_torch(img)

        has_ch_dim = False
        if img_.ndimension() > 1 and img_.shape[1] > 1:
            # convert multi-channel (one-hot) images to argmax
            img_ = torch.argmax(img_, dim=1, keepdim=True)
            has_ch_dim = True

        # combine the tissues from the best-performing models;
        # zeros (rather than torch.empty) so that voxels claimed by no
        # tissue default to background instead of uninitialized memory
        out_pt = torch.zeros(img_.size()[1:])
        for tissue_id, model_id in self.label_model_dict.items():
            temp_best_tissue = img_[model_id, ...]
            out_pt[temp_best_tissue == tissue_id] = tissue_id

        if has_ch_dim:
            # convert back to multi-channel (one-hot)
            num_classes = max(self.label_model_dict.keys()) + 1
            out_pt = one_hot(out_pt, num_classes, dim=0)

        return self.post_convert(out_pt, img)


class SelectBestEnsembled(Ensembled):
    """
    Dictionary-based wrapper of :py:class:`SelectBestEnsemble`.
    """

    backend = SelectBestEnsemble.backend

    def __init__(
        self,
        label_model_dict: dict[int, int],
        keys: KeysCollection,
        output_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to stack and ensemble.
                If only one key is provided, it is assumed to be a PyTorch Tensor
                with the data stacked on dimension `E`.
            output_key: the key under which to store the ensemble result.
                If only one key is provided in `keys`, `output_key` may be None,
                in which case `keys` is used as the default.
        """
        ensemble = SelectBestEnsemble(
            label_model_dict=label_model_dict,
        )
        super().__init__(keys, ensemble, output_key)
```
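The core of `SelectBestEnsemble` is the per-label selection loop. A NumPy sketch of that logic (illustrative only; the transform above operates on stacked torch tensors and handles one-hot inputs):

```python
import numpy as np


def select_best_labels(preds: np.ndarray, label_model_dict: dict[int, int]) -> np.ndarray:
    """Combine label maps from several models: for each tissue label, keep the
    voxels predicted by the model listed in `label_model_dict`.

    `preds` has shape (E, ...) with E indexing the models; output has shape (...).
    """
    out = np.zeros(preds.shape[1:], dtype=preds.dtype)
    for tissue_id, model_id in label_model_dict.items():
        out[preds[model_id] == tissue_id] = tissue_id
    return out


# model 0 is best for tissue 1, model 1 is best for tissue 2
preds = np.array([[0, 1, 1, 2],
                  [0, 1, 2, 2]])
print(select_best_labels(preds, {1: 0, 2: 1}))  # [0 1 2 2]
```

Note that where the selected label maps disagree, later entries in `label_model_dict` overwrite earlier ones, matching the loop order in the transform above.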
New file (`@@ -0,0 +1,110 @@`):

```python
from __future__ import annotations

import torch
from monai.networks import one_hot
from monai.utils import LossReduction
from torch.nn.modules.loss import _Loss


class AsymmetricUnifiedFocalLoss(_Loss):
    """
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.

    Implementation of the Asymmetric Unified Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to
      Handle Class Imbalanced Medical Image Segmentation",
      Michael Yeung, Computerized Medical Imaging and Graphics
    - https://github.com/mlyg/unified-focal-loss/issues/8
    """

    def __init__(
        self,
        to_onehot_y: bool = False,
        num_classes: int = 2,
        gamma: float = 0.75,
        delta: float = 0.7,
        reduction: LossReduction | str = LossReduction.MEAN,
    ):
        """
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            num_classes: number of classes. Defaults to 2.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
            delta: weight of the background. Defaults to 0.7.

        Example:
            >>> import torch
            >>> from monai.losses import AsymmetricUnifiedFocalLoss
            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        """
        super().__init__(reduction=LossReduction(reduction).value)
        self.to_onehot_y = to_onehot_y
        self.num_classes = num_classes
        self.gamma = gamma
        self.delta = delta
        self.epsilon = torch.tensor(1e-7)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Args:
            y_pred: the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.
                The input should be the original logits since they will be transformed
                by a sigmoid in the forward function.
            y_true: the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.

        Raises:
            ValueError: when input and target have different shapes.
            ValueError: when len(y_pred.shape) is neither 4 nor 5.
            ValueError: when the number of classes entered does not match the expected number.
        """
        if y_pred.shape != y_true.shape:
            raise ValueError(
                f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
            )

        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

        if y_pred.shape[1] == 1:
            y_pred = one_hot(y_pred, num_classes=self.num_classes)
            y_true = one_hot(y_true, num_classes=self.num_classes)

        if torch.max(y_true) != self.num_classes - 1:
            raise ValueError(
                f"Please make sure the number of classes is {self.num_classes - 1}"
            )

        # use pytorch CrossEntropyLoss?
        # https://github.com/Project-MONAI/MONAI/blob/2d463a7d19166cff6a83a313f339228bc812912d/monai/losses/dice.py#L741
        epsilon = self.epsilon.to(y_pred.device)  # self.epsilon is already a tensor
        y_pred = torch.clip(y_pred, epsilon, 1.0 - epsilon)
        cross_entropy = -y_true * torch.log(y_pred)

        # calculate losses separately for each class, only enhancing the foreground classes
        back_ce = (
            torch.pow(1 - y_pred[:, 0, ...], self.gamma) * cross_entropy[:, 0, ...]
        )
        back_ce = (1 - self.delta) * back_ce

        losses = [back_ce]
        for i in range(1, self.num_classes):
            i_ce = cross_entropy[:, i, ...]
            i_ce = self.delta * i_ce
            losses.append(i_ce)

        loss = torch.stack(losses)
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(loss)  # sum over the batch and channel dims
        if self.reduction == LossReduction.NONE.value:
            return loss  # unreduced per-class loss stack
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(loss)
        raise ValueError(
            f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
        )
```

Inline review comment on the `gamma: float = 0.75` default:

> gamma=2.0
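To make the asymmetric weighting concrete, here is a pure-Python sketch of the per-voxel loss terms for a two-class case, using the same `gamma` and `delta` defaults as the module above (toy numbers, not part of its API):

```python
import math


def asymmetric_ce_terms(p_back: float, p_fore: float,
                        gamma: float = 0.75, delta: float = 0.7) -> tuple[float, float]:
    # Background cross-entropy is focally suppressed by (1 - p)**gamma and
    # down-weighted by (1 - delta); foreground is only scaled by delta.
    back = (1 - delta) * (1 - p_back) ** gamma * -math.log(p_back)
    fore = delta * -math.log(p_fore)
    return back, fore


back, fore = asymmetric_ce_terms(p_back=0.9, p_fore=0.9)
print(back < fore)  # True: at equal confidence the background term is suppressed
```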
Review comment:

> this should be AsymmetricFocalLoss:
> see https://github.com/mlyg/unified-focal-loss/blob/411c9c5ce43b2ef847f0903d1b841512ad8d2eee/loss_functions.py#L193
>
> to compute AsymmetricUnifiedFocalLoss we also need the AsymmetricTverskyLoss:
> see https://github.com/mlyg/unified-focal-loss/blob/411c9c5ce43b2ef847f0903d1b841512ad8d2eee/loss_functions.py#L351