diff --git a/minerva/__init__.py b/minerva/__init__.py index e69de29..7b1cdcd 100644 --- a/minerva/__init__.py +++ b/minerva/__init__.py @@ -0,0 +1,10 @@ + +import minerva +import minerva.analysis +import minerva.callbacks +import minerva.data +import minerva.losses +import minerva.models +import minerva.pipelines +import minerva.transforms +import minerva.utils diff --git a/minerva/analysis/metrics/transformed_metrics.py b/minerva/analysis/metrics/transformed_metrics.py new file mode 100644 index 0000000..e3d3e56 --- /dev/null +++ b/minerva/analysis/metrics/transformed_metrics.py @@ -0,0 +1,191 @@ +import warnings +from typing import Optional + +import torch +from torchmetrics import Metric + + +class CroppedMetric(Metric): + def __init__( + self, + target_h_size: int, + target_w_size: int, + metric: Metric, + dist_sync_on_step: bool = False, + ): + """ + Initializes a new instance of CroppedMetric. + + Parameters + ---------- + target_h_size: int + The target height size. + target_w_size: int + The target width size. + dist_sync_on_step: bool, optional + Whether to synchronize metric state across processes at each step. + Defaults to False. + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.metric = metric + self.target_h_size = target_h_size + self.target_w_size = target_w_size + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Updates the metric state with the predictions and targets. + + Parameters + ---------- + preds: torch.Tensor + The predicted tensor. + target: + torch.Tensor The target tensor. + """ + + preds = self.crop(preds) + target = self.crop(target) + self.metric.update(preds, target) + + def compute(self) -> float: + """ + Computes the cropped metric. + + Returns: + float: The cropped metric. + """ + return self.metric.compute() + + def crop(self, x: torch.Tensor) -> torch.Tensor: + """crops the input tensor to the target size. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The cropped tensor. + """ + h, w = x.shape[-2:] + start_h = (h - self.target_h_size) // 2 + start_w = (w - self.target_w_size) // 2 + end_h = start_h + self.target_h_size + end_w = start_w + self.target_w_size + + return x[..., start_h:end_h, start_w:end_w] + + +class ResizedMetric(Metric): + def __init__( + self, + target_h_size: Optional[int], + target_w_size: Optional[int], + metric: Metric, + keep_aspect_ratio: bool = False, + dist_sync_on_step: bool = False, + ): + """ + Initializes a new instance of ResizeMetric. + + Parameters + ---------- + target_h_size: int + The target height size. + target_w_size: int + The target width size. + dist_sync_on_step: bool, optional + Whether to synchronize metric state across processes at each step. + Defaults to False. + """ + super().__init__(dist_sync_on_step=dist_sync_on_step) + + if target_h_size is None and target_w_size is None: + raise ValueError( + "At least one of target_h_size or target_w_size must be provided." + ) + + if ( + target_h_size is not None and target_w_size is None + ) and keep_aspect_ratio is False: + warnings.warn( + "A target_w_size is not provided, but keep_aspect_ratio is set to False. keep_aspect_ratio will be set to True. If you want to resize the image to a specific width, please provide a target_w_size." + ) + keep_aspect_ratio = True + + if ( + target_w_size is not None and target_h_size is None + ) and keep_aspect_ratio is False: + warnings.warn( + "A target_h_size is not provided, but keep_aspect_ratio is set to False. 
keep_aspect_ratio will be set to True. If you want to resize the image to a specific height, please provide a target_h_size." + ) + keep_aspect_ratio = True + + self.metric = metric + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.keep_aspect_ratio = keep_aspect_ratio + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Updates the metric state with the predictions and targets. + + Parameters + ---------- + preds: torch.Tensor + The predicted tensor. + target: + torch.Tensor The target tensor. + """ + + preds = self.resize(preds) + target = self.resize(target) + self.metric.update(preds, target) + + def compute(self) -> float: + """ + Computes the resized metric. + + Returns: + float: The resized metric. + """ + return self.metric.compute() + + def resize(self, x: torch.Tensor) -> torch.Tensor: + """Resizes the input tensor to the target size. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The resized tensor. + """ + h, w = x.shape[-2:] + + target_h_size = self.target_h_size + target_w_size = self.target_w_size + if self.keep_aspect_ratio: + if self.target_h_size is None: + scale = target_w_size / w + target_h_size = int(h * scale) + elif self.target_w_size is None: + scale = target_h_size / h + target_w_size = int(w * scale) + type_convert = False + if "LongTensor" in x.type(): + x = x.to(torch.uint8) + type_convert = True + + return ( + torch.nn.functional.interpolate(x, size=(target_h_size, target_w_size)) + if not type_convert + else torch.nn.functional.interpolate( + x, size=(target_h_size, target_w_size) + ).to(torch.long) + ) diff --git a/minerva/callbacks/HyperSearchCallbacks.py b/minerva/callbacks/HyperSearchCallbacks.py new file mode 100644 index 0000000..e24e790 --- /dev/null +++ b/minerva/callbacks/HyperSearchCallbacks.py @@ -0,0 +1,108 @@ +import os +import shutil +import tempfile +from pathlib import Path + +import lightning.pytorch as L +from ray import train +from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.train import Checkpoint + + +class TrainerReportOnIntervalCallback(L.Callback): + + CHECKPOINT_NAME = "checkpoint.ckpt" + + def __init__(self, interval: int = 1) -> None: + super().__init__() + self.trial_name = train.get_context().get_trial_name() + self.local_rank = train.get_context().get_local_rank() + self.tmpdir_prefix = Path(tempfile.gettempdir(), self.trial_name).as_posix() + self.interval = interval + self.step = 0 + if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0: + shutil.rmtree(self.tmpdir_prefix) + + record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1") + + def on_train_epoch_end( + self, trainer: L.Trainer, pl_module: L.LightningModule + ) -> None: + + # Fetch metrics + metrics = trainer.callback_metrics + metrics = {k: v.item() for k, v in metrics.items()} + + # (Optional) Add customized metrics + metrics["epoch"] = trainer.current_epoch + metrics["step"] = trainer.global_step + + tmpdir = Path(self.tmpdir_prefix, str(trainer.current_epoch)).as_posix() + os.makedirs(tmpdir, exist_ok=True) + + if self.step % self.interval == 0: + + # Save checkpoint to local + ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix() + trainer.save_checkpoint(ckpt_path, weights_only=False) + + # Report to train session + checkpoint = Checkpoint.from_directory(tmpdir) + train.report(metrics=metrics, checkpoint=checkpoint) + else: + train.report(metrics=metrics) + + # Add a barrier to ensure all 
workers finished reporting here + trainer.strategy.barrier() + + if self.local_rank == 0: + shutil.rmtree(tmpdir) + + self.step += 1 + + +class TrainerReportKeepOnlyLastCallback(L.Callback): + + CHECKPOINT_NAME = "checkpoint.ckpt" + + def __init__(self) -> None: + super().__init__() + self.trial_name = train.get_context().get_trial_name() + self.local_rank = train.get_context().get_local_rank() + self.tmpdir_prefix = Path(tempfile.gettempdir(), self.trial_name).as_posix() + if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0: + shutil.rmtree(self.tmpdir_prefix) + + record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1") + + def on_train_epoch_end( + self, trainer: L.Trainer, pl_module: L.LightningModule + ) -> None: + # Fetch metrics + metrics = trainer.callback_metrics + metrics = {k: v.item() for k, v in metrics.items()} + + # (Optional) Add customized metrics + metrics["epoch"] = trainer.current_epoch + metrics["step"] = trainer.global_step + + tmpdir = Path(self.tmpdir_prefix, "last").as_posix() + os.makedirs(tmpdir, exist_ok=True) + + # Delete previous checkpoint + if os.path.isdir(tmpdir): + shutil.rmtree(tmpdir) + + # Save checkpoint to local + ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix() + trainer.save_checkpoint(ckpt_path, weights_only=False) + + # Report to train session + checkpoint = Checkpoint.from_directory(tmpdir) + train.report(metrics=metrics, checkpoint=checkpoint) + + # Add a barrier to ensure all workers finished reporting here + trainer.strategy.barrier() + + if self.local_rank == 0: + shutil.rmtree(tmpdir) diff --git a/minerva/callbacks/__init__.py b/minerva/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minerva/data/datasets/supervised_dataset.py b/minerva/data/datasets/supervised_dataset.py index cccf98e..7160077 100644 --- a/minerva/data/datasets/supervised_dataset.py +++ b/minerva/data/datasets/supervised_dataset.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import numpy as np @@ -15,7 +15,7 @@ class SupervisedReconstructionDataset(SimpleDataset): Usually, both input and target data have the same shape. This dataset is useful for supervised tasks such as image reconstruction, - segmantic segmentation, and object detection, where the input data is the + semantic segmentation, and object detection, where the input data is the original data and the target is a mask or a segmentation map. Examples @@ -45,7 +45,12 @@ class SupervisedReconstructionDataset(SimpleDataset): ``` """ - def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = None): + def __init__( + self, + readers: List[_Reader], + transforms: Optional[_Transform] = None, + support_context_transforms: bool = False, + ): """A simple dataset class for supervised reconstruction tasks. Parameters @@ -62,12 +67,13 @@ def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = No AssertionError: If the number of readers is not exactly 2. """ super().__init__(readers, transforms) + self.support_context_transforms = support_context_transforms assert ( len(self.readers) == 2 ), "SupervisedReconstructionDataset requires exactly 2 readers" - def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: int) -> Tuple[Any, Any]: """Load data from sources and apply specified transforms. The same transform is applied to both input and target data. 
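
For reference, a minimal usage sketch of the CroppedMetric and ResizedMetric wrappers introduced in transformed_metrics.py above. The JaccardIndex choice, class count, and tensor shapes are illustrative assumptions and not part of this patch; the wrappers simply crop or resize predictions and targets before delegating to any torchmetrics metric.

```python
# Sketch only (assumed shapes / metric choice): the new wrappers delegate to a
# wrapped torchmetrics.Metric after cropping or resizing preds and targets.
import torch
from torchmetrics import JaccardIndex

from minerva.analysis.metrics.transformed_metrics import CroppedMetric, ResizedMetric

# Score only the central 128x128 crop of 255x255 predictions/targets.
iou = CroppedMetric(
    target_h_size=128,
    target_w_size=128,
    metric=JaccardIndex(task="multiclass", num_classes=6),
)

preds = torch.randint(0, 6, (4, 255, 255))
target = torch.randint(0, 6, (4, 255, 255))
iou.update(preds, target)
print(iou.compute())

# ResizedMetric resizes both tensors before delegating; providing only a height
# triggers the keep_aspect_ratio warning added in this patch.
resized_iou = ResizedMetric(
    target_h_size=128,
    target_w_size=None,
    metric=JaccardIndex(task="multiclass", num_classes=6),
)
```
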
@@ -78,10 +84,29 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: Returns ------- - Tuple[np.ndarray, np.ndarray] - A tuple containing two numpy arrays representing the data. + Tuple[Any, Any] + A tuple containing two elements: the input data and the target data. """ - data = super().__getitem__(index) - - return (data[0], data[1]) + if not self.support_context_transforms: + data = super().__getitem__(index) + + return (data[0], data[1]) + else: + + data = [] + + # For each reader and transform, read the data and apply the transform. + # Then, append the transformed data to the list of data. + for reader, transform in zip(reversed(self.readers), self.transforms): + sample = reader[index] + # Apply the transform if it is not None + if transform is not None: + sample = transform(sample) + data.append(sample) + # Return the list of transformed data or a single sample if return_single + # is True and there is only one reader. + if self.return_single: + return data[1] + else: + return tuple(reversed(data)) diff --git a/minerva/engines/engine.py b/minerva/engines/engine.py new file mode 100644 index 0000000..60ac7dc --- /dev/null +++ b/minerva/engines/engine.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +import lightning.pytorch as L +import numpy as np +import torch + + +class _Engine: + def __init__(self) -> None: + super().__init__() + + def __call__( + self, + model: Union[L.LightningModule, torch.nn.Module], + x: Union[torch.Tensor, np.ndarray], + ): + raise NotImplementedError diff --git a/minerva/engines/patch_inferencer_engine.py b/minerva/engines/patch_inferencer_engine.py index f88147c..74185b5 100644 --- a/minerva/engines/patch_inferencer_engine.py +++ b/minerva/engines/patch_inferencer_engine.py @@ -1,10 +1,14 @@ -from typing import List, Tuple, Optional, Dict, Any -import torch -import numpy as np +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + import lightning as L +import numpy as np +import torch +import torch.nn as nn + +from minerva.engines.engine import _Engine -class BasePatchInferencer: +class PatchInferencer(L.LightningModule): """Inference in patches for models This class provides utility methods for performing inference in patches @@ -15,9 +19,10 @@ def __init__( model: L.LightningModule, input_shape: Tuple, output_shape: Optional[Tuple] = None, - weight_function: Optional[callable] = None, + weight_function: Optional[Callable] = None, offsets: Optional[List[Tuple]] = None, padding: Optional[Dict[str, Any]] = None, + return_tuple: Optional[int] = None, ): """Initialize the patch inference auxiliary class @@ -37,12 +42,116 @@ def __init__( padding : Dict[str, Any], optional Dictionary describing padding strategy. Keys: pad: tuple with pad width (int) for each dimension, e.g. (0, 3, 3) when working with a tensor with 3 dimensions - mode (optional): 'constant', 'reflect', 'replicate' or 'cicular'. Defaults to 'constant'. - value (optional): fill value for 'constante'. Defaults to 0. + mode (optional): 'constant', 'reflect', 'replicate' or 'circular'. Defaults to 'constant'. + value (optional): fill value for 'constant'. Defaults to 0. 
""" + super().__init__() self.model = model - self.input_shape = input_shape - self.output_shape = output_shape if output_shape is not None else input_shape + self.patch_inferencer = PatchInferencerEngine( + input_shape, output_shape, offsets, padding, weight_function, return_tuple + ) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform Inference in Patches + + Parameters + ---------- + x : torch.Tensor + Input Tensor. + """ + return self.patch_inferencer(self.model, x) + + def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): + """Perform a single step of the training/validation loop. + + Parameters + ---------- + batch : torch.Tensor + The input data. + batch_idx : int + The index of the batch. + step_name : str + The name of the step, either "train" or "val". + + Returns + ------- + torch.Tensor + The loss value. + """ + x, y = batch + y_hat = self.forward(x.float()) + loss = self.model._loss_func(y_hat, y.squeeze(1)) + + metrics = self.model._compute_metrics(y_hat, y, step_name) + for metric_name, metric_value in metrics.items(): + self.log( + metric_name, + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + f"{step_name}_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "train") + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "val") + + def test_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "test") + + +# region _PatchInferencer +class PatchInferencerEngine(_Engine): + + def __init__( + self, + input_shape: Tuple[int], + output_shape: Optional[Tuple[int]] = None, + offsets: Optional[List[Tuple]] = None, + padding: Optional[Dict[str, Any]] = None, + weight_function: Optional[Callable] = None, + return_tuple: Optional[int] = None, + ): + """ + Parameters + ---------- + model : nn.Module + The neural network model for inference. + input_shape : Tuple[int] + Shape of each patch to process. + output_shape : Tuple[int], optional + Expected shape of the model output per patch. Defaults to input_shape. + padding : dict, optional + Padding configuration with keys: + - 'pad': Tuple of padding for each expected final dimension, e.g., (0, 512, 512) - (c, h, w). + - 'mode': Padding mode, e.g., 'constant', 'reflect'. + - 'value': Padding value if mode is 'constant'. 
+ """ + self.input_shape = (1, *input_shape) + self.output_shape = ( + (1, *output_shape) if output_shape is not None else self.input_shape + ) + self.weight_function = weight_function if offsets is not None: @@ -59,44 +168,34 @@ def __init__( padding["pad"] ), f"Pad tuple does not match expected size ({len(input_shape)})" self.padding = padding + self.padding["pad"] = (0, *self.padding["pad"]) else: - self.padding = {"pad": tuple([0] * len(input_shape))} - - def __call__(self, x: torch.Tensor) -> torch.Tensor: - return self.forward(x) + self.padding = {"pad": tuple([0] * (len(input_shape) + 1))} + self.return_tuple = return_tuple def _reconstruct_patches( self, patches: torch.Tensor, index: Tuple[int], - weights: bool, - inner_dim: int = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Rearranges patches to reconstruct area of interest from patches and weights """ reconstruct_shape = np.array(self.output_shape) * np.array(index) - if weights: - weight = torch.zeros(tuple(reconstruct_shape)) - base_weight = ( - self.weight_function(self.input_shape) - if self.weight_function - else torch.ones(self.input_shape) - ) - else: - weight = None - if inner_dim is not None: - reconstruct_shape = np.append(reconstruct_shape, inner_dim) - reconstruct = torch.zeros(tuple(reconstruct_shape)) + weight = torch.zeros(tuple(reconstruct_shape), device=patches.device) + base_weight = ( + self.weight_function(self.output_shape) + if self.weight_function + else torch.ones(self.output_shape, device=patches.device) + ) + + reconstruct = torch.zeros(tuple(reconstruct_shape), device=patches.device) for patch_index, patch in zip(np.ndindex(index), patches): sl = [ slice(idx * patch_len, (idx + 1) * patch_len, None) - for idx, patch_len in zip(patch_index, self.input_shape) + for idx, patch_len in zip(patch_index, self.output_shape) ] - if weights: - weight[tuple(sl)] = base_weight - if inner_dim is not None: - sl.append(slice(None, None, None)) + weight[tuple(sl)] = base_weight reconstruct[tuple(sl)] = patch return reconstruct, weight @@ -110,30 +209,20 @@ def _adjust_patches( """ Pads reconstructed_patches with 'pad_value' to have same shape as the reference shape from the base patch set """ - has_inner_dim = len(offset) < len(arrays[0].shape) pad_width = [] sl = [] ref_shape = list(ref_shape) arr_shape = list(arrays[0].shape) - if has_inner_dim: - arr_shape = arr_shape[:-1] - for idx, lenght, ref in zip(offset, arr_shape, ref_shape): + for idx, length, ref in zip([0, *offset], arr_shape, ref_shape): if idx > 0: - sl.append(slice(0, min(lenght, ref), None)) - pad_width = [idx, max(ref - lenght - idx, 0)] + pad_width + sl.append(slice(0, min(length, ref), None)) + pad_width = [idx, max(ref - length - idx, 0)] + pad_width else: - sl.append(slice(np.abs(idx), min(lenght, ref - idx), None)) - pad_width = [0, max(ref - lenght - idx, 0)] + pad_width + sl.append(slice(np.abs(idx), min(length, ref - idx), None)) + pad_width = [0, max(ref - length - idx, 0)] + pad_width adjusted = [ ( torch.nn.functional.pad( - arr[tuple([*sl, slice(None, None, None)])], - pad=tuple([0, 0, *pad_width]), - mode="constant", - value=pad_value, - ) - if has_inner_dim - else torch.nn.functional.pad( arr[tuple(sl)], pad=tuple(pad_width), mode="constant", @@ -151,11 +240,20 @@ def _combine_patches( indexes: List[Tuple[int]], ) -> torch.Tensor: """ - How results are combined is dependent on what is being combined. 
- RegressionPatchInferencer uses Weighted Average - ClassificationPatchInferencer uses Voting (hard or soft) + Combination of results """ - raise NotImplementedError("Combine patches method must be implemented") + reconstructed = [] + weights = [] + for patches, offset, shape in zip(results, offsets, indexes): + reconstruct, weight = self._reconstruct_patches(patches, shape) + reconstruct, weight = self._adjust_patches( + [reconstruct, weight], self.ref_shape, offset + ) + reconstructed.append(reconstruct) + weights.append(weight) + reconstructed = torch.stack(reconstructed, dim=0) + weights = torch.stack(weights, dim=0) + return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0) def _extract_patches( self, data: torch.Tensor, patch_shape: Tuple[int] @@ -187,7 +285,18 @@ def _compute_output_shape(self, tensor: torch.Tensor) -> Tuple[int]: shape.append(t) return tuple(shape) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def _compute_base_padding(self, tensor: torch.Tensor): + """ + Computes the padding for the base patch set based on the input tensor shape and the model's input shape. + """ + padding = [0, 0] + for i, t in zip(self.padding["pad"][2:], tensor.shape[1:]): + padding.append(max(0, i - t)) + return padding + + def __call__( + self, model: Union[L.LightningModule, torch.nn.Module], x: torch.Tensor + ): """ Perform Inference in Patches @@ -196,20 +305,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x : torch.Tensor Input Tensor. """ - assert len(x.shape) == len( - self.input_shape - ), "Input and self.input_shape sizes must match" + if len(x.shape) == len(self.input_shape) - 1: + x = x.unsqueeze(0) + elif len(x.shape) == len(self.input_shape): + pass + else: + raise RuntimeError("Invalid input shape") self.ref_shape = self._compute_output_shape(x) offsets = list(self.offsets) - base = self.padding["pad"] - offsets.insert(0, tuple([0] * len(base))) - + base = self._compute_base_padding(x) + offsets.insert(0, tuple([0] * (len(base) - 1))) slices = [ tuple( [ - slice(i + base, None) # TODO: if ((i + base >= 0) and (i < in_dim)) - for i, base, in_dim in zip(offset, base, x.shape) + slice(i, None) # TODO: if ((i + base >= 0) and (i < in_dim)) + for i, in_dim in zip([0, *offset], x.shape) ] ) for offset in offsets @@ -217,153 +328,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch_pad = [] for pad_value in reversed(base): - torch_pad = torch_pad + [pad_value, pad_value] + torch_pad = torch_pad + [0, pad_value] x_padded = torch.nn.functional.pad( x, pad=tuple(torch_pad), mode=self.padding.get("mode", "constant"), value=self.padding.get("value", 0), ) - results = [] + results = ( + tuple([] for _ in range(self.return_tuple)) if self.return_tuple else [] + ) indexes = [] for sl in slices: patch_set, patch_idx = self._extract_patches(x_padded[sl], self.input_shape) - results.append(self.model(patch_set)) + patch_set = patch_set.squeeze(1) + inference = model(patch_set) + if self.return_tuple: + for i in range(self.return_tuple): + results[i].append(inference[i]) + else: + results.append(inference) indexes.append(patch_idx) - output_slice = tuple( - [slice(0, lenght) for lenght in x.shape] - ) - return self._combine_patches(results, offsets, indexes)[output_slice] - - -class WeightedAvgPatchInferencer(BasePatchInferencer): - """ - PatchInferencer with Weighted Average combination function. 
- """ - - def _combine_patches( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - reconstructed = [] - weights = [] - for patches, offset, shape in zip(results, offsets, indexes): - reconstruct, weight = self._reconstruct_patches( - patches, shape, weights=True - ) - reconstruct, weight = self._adjust_patches( - [reconstruct, weight], self.ref_shape, offset - ) - - reconstructed.append(reconstruct) - weights.append(weight) - reconstructed = torch.stack(reconstructed, dim=0) - weights = torch.stack(weights, dim=0) - return torch.sum(reconstructed * weights, dim=0) / torch.sum(weights, dim=0) - - -class VotingPatchInferencer(BasePatchInferencer): - """ - PatchInferencer with Voting combination function. - Note: Models used with VotingPatchInferencer must return class probabilities in inner dimension - """ - - def __init__( - self, - model: L.LightningModule, - num_classes: int, - input_shape: Tuple, - output_shape: Optional[Tuple] = None, - weight_function: Optional[callable] = None, - offsets: Optional[List[Tuple]] = None, - padding: Optional[Dict[str, Any]] = None, - voting: str = "soft", - ): - """Initialize the patch inference auxiliary class - - Parameters - ---------- - model : L.LightningModule - Model used in inference. - num_classes: int - number of classes of the classification task - input_shape : Tuple - Expected input shape of the model - output_shape : Tuple, optional - Expected output shape of the model. Defaults to input_shape - weight_function: callable, optional - Function that receives a tensor shape and returns the weights for each position of a tensor with the given shape - Useful when regions of the inference present diminishing performance when getting closer to borders, for instance. - offsets : Tuple, optional - List of tuples with offsets that determine the shift of the initial position of the patch subdivision - padding : Dict[str, Any], optional - Dictionary describing padding strategy. Keys: - pad: tuple with pad width (int) for each dimension, e.g. (0, 3, 3) when working with a tensor with 3 dimensions - mode (optional): 'constant', 'reflect', 'replicate' or 'cicular'. Defaults to 'constant'. - value (optional): fill value for 'constante'. Defaults to 0. - voting: str - voting method to use, can be either 'soft'or 'hard'. Defaults to 'soft'. 
- """ - super().__init__( - model, input_shape, output_shape, weight_function, offsets, padding - ) - assert voting in ["soft", "hard"], "voting should be either 'soft' or 'hard'" - self.num_classes = num_classes - self.voting = voting - - def _combine_patches( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - voting_method = getattr(self, f"_{self.voting}_voting") - return voting_method(results, offsets, indexes) - - def _hard_voting( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - """ - Hard voting combination function - """ - # torch.mode does not work like scipy.stats.mode - raise NotImplementedError("Hard voting not yet supported") - # reconstructed = [] - # for patches, offset, shape in zip(results, offsets, indexes): - # reconstruct, _ = self._reconstruct_patches( - # patches, shape, weights=False, inner_dim=self.num_classes - # ) - # reconstruct = torch.argmax(reconstruct, dim=-1).float() - # reconstruct = self._adjust_patches( - # [reconstruct], self.ref_shape, offset, pad_value=torch.nan - # )[0] - # reconstructed.append(reconstruct) - # reconstructed = torch.stack(reconstructed, dim=0) - # ret = torch.mode(reconstructed, dim=0, keepdims=False)[ - # 0 - # ] # TODO check behaviour on GPU, according to issues may have nonsense results - # return ret - - def _soft_voting( - self, - results: List[torch.Tensor], - offsets: List[Tuple[int]], - indexes: List[Tuple[int]], - ) -> torch.Tensor: - """ - Soft voting combination function - """ - reconstructed = [] - for patches, offset, shape in zip(results, offsets, indexes): - reconstruct, _ = self._reconstruct_patches( - patches, shape, weights=False, inner_dim=self.num_classes - ) - reconstruct = self._adjust_patches([reconstruct], self.ref_shape, offset)[0] - reconstructed.append(reconstruct) - reconstructed = torch.stack(reconstructed, dim=0) - return torch.argmax(torch.sum(reconstructed, dim=0), dim=-1) + output_slice = tuple([slice(0, length) for length in self.ref_shape]) + if self.return_tuple: + comb_list = [] + for i in range(self.return_tuple): + comb = self._combine_patches(results[i], offsets, indexes) + comb = comb[output_slice] + comb_list.append(comb) + comb = tuple(comb_list) + else: + comb = self._combine_patches(results, offsets, indexes) + comb = comb[output_slice] + return comb diff --git a/minerva/models/nets/image/setr.py b/minerva/models/nets/image/setr.py index 094b4ec..9a39413 100644 --- a/minerva/models/nets/image/setr.py +++ b/minerva/models/nets/image/setr.py @@ -1,15 +1,18 @@ import warnings from typing import Dict, List, Optional, Tuple, Union -import lightning as L +import lightning.pytorch as L import torch from torch import nn +from torch.optim.adam import Adam from torchmetrics import Metric +from minerva.engines.engine import _Engine from minerva.models.nets.image.vit import _VisionTransformerBackbone from minerva.utils.upsample import Upsample +# region _SETRUPHead class _SETRUPHead(nn.Module): """Naive upsampling head and Progressive upsampling head of SETR. @@ -119,9 +122,10 @@ def forward(self, x): return out +# region _SETRMLAHead class _SETRMLAHead(nn.Module): """Multi level feature aggretation head of SETR. - + This has not been tested yet. MLA head of `SETR `_. 
""" @@ -218,6 +222,7 @@ def forward(self, x): return out +# region _SetR_PUP class _SetR_PUP(nn.Module): def __init__( @@ -240,8 +245,9 @@ def __init__( conv_norm: nn.Module, conv_act: nn.Module, align_corners: bool, - aux_output: bool = False, - aux_output_layers: Optional[List[int]] = None, + aux_output: bool, + aux_output_layers: list[int] | None, + original_resolution: Optional[Tuple[int, int]], ): """ Initializes the SETR PUP model. @@ -307,6 +313,7 @@ def __init__( dropout=encoder_dropout, aux_output=aux_output, aux_output_layers=aux_output_layers, + original_resolution=original_resolution, ) self.decoder = _SETRUPHead( @@ -383,8 +390,108 @@ def forward(self, x: torch.Tensor): x = self.decoder(x) return x + def load_backbone(self, path: str, freeze: bool = False): + self.encoder.load_backbone(path) + if freeze: + for param in self.encoder.parameters(): + param.requires_grad = False + +# region SETR_PUP class SETR_PUP(L.LightningModule): + """ + SETR_PUP is a PyTorch Lightning Module for the SETR (Segmenter Transformer) model with Patch Up-sampling (PUP). + + Parameters + ---------- + image_size : Union[int, Tuple[int, int]], default=512 + The size of the input image. + patch_size : int, default=16 + The size of the patches to be extracted from the input image. + num_layers : int, default=24 + The number of transformer layers in the encoder. + num_heads : int, default=16 + The number of attention heads in each transformer layer. + hidden_dim : int, default=1024 + The hidden dimension of the transformer layers. + mlp_dim : int, default=4096 + The dimension of the MLP (Feed-Forward) layers in the transformer. + encoder_dropout : float, default=0.1 + The dropout rate for the encoder. + num_classes : int, default=1000 + The number of output classes. + norm_layer : Optional[nn.Module], default=None + The normalization layer to be used in the transformer. + decoder_channels : int, default=256 + The number of channels in the decoder. + num_convs : int, default=4 + The number of convolutional layers in the decoder. + up_scale : int, default=2 + The up-sampling scale factor. + kernel_size : int, default=3 + The kernel size for the convolutional layers. + align_corners : bool, default=False + Whether to align corners when interpolating. + decoder_dropout : float, default=0.1 + The dropout rate for the decoder. + conv_norm : Optional[nn.Module], default=None + The normalization layer to be used in the convolutional layers. + conv_act : Optional[nn.Module], default=None + The activation function to be used in the convolutional layers. + interpolate_mode : str, default="bilinear" + The interpolation mode to be used for up-sampling. + loss_fn : Optional[nn.Module], default=None + The loss function to be used. + optimizer_type : Optional[type], default=None + The type of optimizer to be used. + optimizer_params : Optional[Dict], default=None + The parameters for the optimizer. + train_metrics : Optional[Dict[str, Metric]], default=None + The metrics to be used during training. + val_metrics : Optional[Dict[str, Metric]], default=None + The metrics to be used during validation. + test_metrics : Optional[Dict[str, Metric]], default=None + The metrics to be used during testing. + aux_output : bool, default=True + Whether to use auxiliary outputs. + aux_output_layers : list[int] | None, default=[9, 14, 19] + The layers from which to take auxiliary outputs. + aux_weights : list[float], default=[0.3, 0.3, 0.3] + The weights for the auxiliary outputs. 
+ load_backbone_path : Optional[str], default=None + The path to the pre-trained backbone to be loaded. + freeze_backbone_on_load : bool, default=True + Whether to freeze the backbone after loading. + learning_rate : float, default=1e-3 + The learning rate for the optimizer. + loss_weights : Optional[list[float]], default=None + The weights for the loss function. + + Methods + ------- + forward(x: torch.Tensor) -> torch.Tensor + Forward pass of the model. + _compute_metrics(y_hat: torch.Tensor, y: torch.Tensor, step_name: str) + Compute metrics for the given step. + _loss_func(y_hat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], y: torch.Tensor) -> torch.Tensor + Calculate the loss between the output and the input data. + _single_step(batch: torch.Tensor, batch_idx: int, step_name: str) + Perform a single step of the training/validation loop. + training_step(batch: torch.Tensor, batch_idx: int) + Perform a single training step. + validation_step(batch: torch.Tensor, batch_idx: int) + Perform a single validation step. + test_step(batch: torch.Tensor, batch_idx: int) + Perform a single test step. + predict_step(batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None) + Perform a single prediction step. + load_backbone(path: str, freeze: bool = False) + Load a pre-trained backbone. + configure_optimizers() + Configure the optimizer for the model. + create_from_dict(config: Dict) -> "SETR_PUP" + Create an instance of SETR_PUP from a configuration dictionary. + """ def __init__( self, @@ -407,72 +514,104 @@ def __init__( conv_act: Optional[nn.Module] = None, interpolate_mode: str = "bilinear", loss_fn: Optional[nn.Module] = None, + optimizer_type: Optional[type] = None, + optimizer_params: Optional[Dict] = None, train_metrics: Optional[Dict[str, Metric]] = None, val_metrics: Optional[Dict[str, Metric]] = None, test_metrics: Optional[Dict[str, Metric]] = None, aux_output: bool = True, - aux_output_layers: Optional[List[int]] = [9, 14, 19], - aux_weights: List[float] = [0.3, 0.3, 0.3], + aux_output_layers: list[int] | None = [9, 14, 19], + aux_weights: list[float] = [0.3, 0.3, 0.3], + load_backbone_path: Optional[str] = None, + freeze_backbone_on_load: bool = True, + learning_rate: float = 1e-3, + loss_weights: Optional[list[float]] = None, + original_resolution: Optional[Tuple[int, int]] = None, + head_lr_factor: float = 1.0, + test_engine: Optional[_Engine] = None, ): """ - Initializes the SetR model. + Initialize the SETR model. Parameters ---------- - image_size : int or Tuple[int, int] - The input image size. Defaults to 512. - patch_size : int - The size of each patch. Defaults to 16. - num_layers : int - The number of layers in the transformer encoder. Defaults to 24. - num_heads : int - The number of attention heads in the transformer encoder. Defaults to 16. - hidden_dim : int - The hidden dimension of the transformer encoder. Defaults to 1024. - mlp_dim : int - The dimension of the MLP layers in the transformer encoder. Defaults to 4096. - encoder_dropout : float - The dropout rate for the transformer encoder. Defaults to 0.1. - num_classes : int - The number of output classes. Defaults to 1000. - norm_layer : nn.Module, optional - The normalization layer to be used in the decoder. Defaults to None. - decoder_channels : int - The number of channels in the decoder. Defaults to 256. - num_convs : int - The number of convolutional layers in the decoder. Defaults to 4. - up_scale : int - The scale factor for upsampling in the decoder. 
Defaults to 2. - kernel_size : int - The kernel size for convolutional layers in the decoder. Defaults to 3. - align_corners : bool - Whether to align corners during interpolation in the decoder. Defaults to False. - decoder_dropout : float - The dropout rate for the decoder. Defaults to 0.1. - conv_norm : nn.Module, optional - The normalization layer to be used in the convolutional layers of the decoder. Defaults to None. - conv_act : nn.Module, optional - The activation function to be used in the convolutional layers of the decoder. Defaults to None. - interpolate_mode : str - The interpolation mode for upsampling in the decoder. Defaults to "bilinear". - loss_fn : nn.Module, optional - The loss function to be used during training. Defaults to None. - train_metrics : Dict[str, Metric], optional - The metrics to be used for training evaluation. Defaults to None. - val_metrics : Dict[str, Metric], optional - The metrics to be used for validation evaluation. Defaults to None. - test_metrics : Dict[str, Metric], optional - The metrics to be used for testing evaluation. Defaults to None. - aux_output : bool - Whether to include auxiliary output heads in the model. Defaults to True. - aux_output_layers : List[int], optional - The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19]. - aux_weights : List[float] - The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3]. - + image_size : Union[int, Tuple[int, int]], optional + Size of the input image, by default 512. + patch_size : int, optional + Size of the patches to be extracted from the input image, by default 16. + num_layers : int, optional + Number of transformer layers, by default 24. + num_heads : int, optional + Number of attention heads, by default 16. + hidden_dim : int, optional + Dimension of the hidden layer, by default 1024. + mlp_dim : int, optional + Dimension of the MLP layer, by default 4096. + encoder_dropout : float, optional + Dropout rate for the encoder, by default 0.1. + num_classes : int, optional + Number of output classes, by default 1000. + norm_layer : Optional[nn.Module], optional + Normalization layer, by default None. + decoder_channels : int, optional + Number of channels in the decoder, by default 256. + num_convs : int, optional + Number of convolutional layers in the decoder, by default 4. + up_scale : int, optional + Upscaling factor for the decoder, by default 2. + kernel_size : int, optional + Kernel size for the convolutional layers, by default 3. + align_corners : bool, optional + Whether to align corners when interpolating, by default False. + decoder_dropout : float, optional + Dropout rate for the decoder, by default 0.1. + conv_norm : Optional[nn.Module], optional + Normalization layer for the convolutional layers, by default None. + conv_act : Optional[nn.Module], optional + Activation function for the convolutional layers, by default None. + interpolate_mode : str, optional + Interpolation mode, by default "bilinear". + loss_fn : Optional[nn.Module], optional + Loss function, by default None. + optimizer_type : Optional[type], optional + Type of optimizer, by default None. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer, by default None. + train_metrics : Optional[Dict[str, Metric]], optional + Metrics for training, by default None. + val_metrics : Optional[Dict[str, Metric]], optional + Metrics for validation, by default None. + test_metrics : Optional[Dict[str, Metric]], optional + Metrics for testing, by default None. 
+ aux_output : bool, optional + Whether to use auxiliary outputs, by default True. + aux_output_layers : list[int] | None, optional + Layers for auxiliary outputs, by default [9, 14, 19]. + aux_weights : list[float], optional + Weights for auxiliary outputs, by default [0.3, 0.3, 0.3]. + load_backbone_path : Optional[str], optional + Path to load the backbone model, by default None. + freeze_backbone_on_load : bool, optional + Whether to freeze the backbone model on load, by default True. + learning_rate : float, optional + Learning rate, by default 1e-3. + loss_weights : Optional[list[float]], optional + Weights for the loss function, by default None. """ + super().__init__() - self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() + + if head_lr_factor != 1: + self.automatic_optimization = False + self.multiple_optimizers = True + + self.loss_fn = ( + loss_fn + if loss_fn is not None + else nn.CrossEntropyLoss( + weight=torch.tensor(loss_weights) if loss_weights is not None else None + ) + ) norm_layer = norm_layer if norm_layer is not None else nn.LayerNorm(hidden_dim) conv_norm = ( conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) @@ -488,8 +627,14 @@ def __init__( aux_output_layers ), "aux_weights must have the same length as aux_output_layers." + self.optimizer_type = optimizer_type + if optimizer_type is not None: + assert optimizer_params is not None, "optimizer_params must be provided." + self.optimizer_params = optimizer_params + self.num_classes = num_classes self.aux_weights = aux_weights + self.head_lr_factor = head_lr_factor self.metrics = { "train": train_metrics, @@ -518,7 +663,13 @@ def __init__( align_corners=align_corners, aux_output=aux_output, aux_output_layers=aux_output_layers, + original_resolution=original_resolution, ) + if load_backbone_path is not None: + self.model.load_backbone(load_backbone_path, freeze_backbone_on_load) + + self.learning_rate = learning_rate + self.test_engine = test_engine def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) @@ -589,10 +740,14 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): The loss value. 
""" x, y = batch - y_hat = self.model(x.float()) - loss = self._loss_func(y_hat[0], y.squeeze(1)) + if self.test_engine and (step_name == "test" or step_name == "val"): + y_hat = self.test_engine(self.model, x) + else: + y_hat = self.model(x) metrics = self._compute_metrics(y_hat[0], y, step_name) + loss = self._loss_func(y_hat, y.squeeze(1)) + for metric_name, metric_value in metrics.items(): self.log( metric_name, @@ -617,7 +772,20 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): return loss def training_step(self, batch: torch.Tensor, batch_idx: int): - return self._single_step(batch, batch_idx, "train") + if self.multiple_optimizers: + optimizers_list = self.optimizers() + + for opt in optimizers_list: + opt.zero_grad() + + loss = self._single_step(batch, batch_idx, "train") + + self.manual_backward(loss) + + for opt in optimizers_list: + opt.step() + else: + return self._single_step(batch, batch_idx, "train") def validation_step(self, batch: torch.Tensor, batch_idx: int): return self._single_step(batch, batch_idx, "val") @@ -631,5 +799,44 @@ def predict_step( x, _ = batch return self.model(x)[0] + def load_backbone(self, path: str, freeze: bool = False): + self.model.load_backbone(path, freeze) + def configure_optimizers(self): - return torch.optim.Adam(self.model.parameters(), lr=1e-3) + if self.multiple_optimizers: + return ( + [ + self.optimizer_type( + self.model.encoder.parameters(), + lr=self.learning_rate, + **self.optimizer_params, + ), + self.optimizer_type( + list(self.model.decoder.parameters()) + + list(self.model.aux_head1.parameters()) + + list(self.model.aux_head2.parameters()) + + list(self.model.aux_head3.parameters()), + lr=self.learning_rate * self.head_lr_factor, + **self.optimizer_params, + ), + ] + if self.optimizer_type is not None + else [ + Adam(self.model.encoder.parameters(), lr=self.learning_rate), + Adam(self.model.decoder.parameters(), lr=self.learning_rate), + ] + ) + else: + return ( + self.optimizer_type( + self.model.parameters(), + lr=self.learning_rate, + **self.optimizer_params, + ) + if self.optimizer_type is not None + else Adam(self.model.parameters(), lr=self.learning_rate) + ) + + @staticmethod + def create_from_dict(config: Dict) -> "SETR_PUP": + return SETR_PUP(**config) diff --git a/minerva/models/nets/image/vit.py b/minerva/models/nets/image/vit.py index 238c63f..ff33a5d 100644 --- a/minerva/models/nets/image/vit.py +++ b/minerva/models/nets/image/vit.py @@ -8,6 +8,7 @@ import torch.nn as nn from timm.models.vision_transformer import Block, PatchEmbed from torch import nn +from torch.nn import functional as F from torchvision.models.vision_transformer import ( Conv2dNormActivation, ConvStemConfig, @@ -87,6 +88,7 @@ def __init__( num_heads: int, hidden_dim: int, mlp_dim: int, + original_resolution: Optional[Tuple[int, int]] = None, dropout: float = 0.0, attention_dropout: float = 0.0, num_classes: int = 1000, @@ -113,6 +115,8 @@ def __init__( The dimensionality of the hidden layers in the transformer. mlp_dim : int The dimensionality of the feed-forward MLP layers in the transformer. + original_resolution : Tuple[int, int], optional + The original resolution of the input image in the pre-training weights. When None, positional embeddings will not be interpolated. Defaults to None. dropout : float, optional The dropout rate to apply. Defaults to 0.0. 
attention_dropout : float, optional @@ -156,6 +160,9 @@ def __init__( self.norm_layer = norm_layer self.aux_output = aux_output self.aux_output_layers = aux_output_layers + self.original_resolution = ( + original_resolution if original_resolution else image_size + ) if conv_stem_configs is not None: # As per https://arxiv.org/abs/2106.14881 @@ -284,6 +291,93 @@ def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: return x, n_h, n_w + def interpolate_pos_embeddings(self, pretrained_pos_embed, new_img_size): + """Interpolate encoder's positional embeddings to fit a new input size. + + Args: + pretrained_pos_embed (torch.Tensor): Pretrained positional embeddings. + new_img_size (Tuple[int, int]): New height and width of the input image. + """ + h, w = new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size + new_grid_size = (h, w) + + # Reshape pretrained positional embeddings to match the original grid size + + original_resolution = ( + self.original_resolution + if isinstance(self.original_resolution, Tuple) + else (self.original_resolution, self.original_resolution) + ) + + pos_embed_reshaped = pretrained_pos_embed[:, 1:].reshape( + 1, + original_resolution[0] // self.patch_size, + original_resolution[1] // self.patch_size, + -1, + ) + + # Interpolate positional embeddings to the new grid size + pos_embed_interpolated = ( + F.interpolate( + pos_embed_reshaped.permute( + 0, 3, 1, 2 + ), # (1, C, H, W) for interpolation + size=new_grid_size, + mode="bilinear", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .reshape(1, -1, pos_embed_reshaped.shape[-1]) + ) + + # Concatenate the CLS token and the interpolated positional embeddings + cls_token = pretrained_pos_embed[:, :1] + pos_embed_interpolated = torch.cat((cls_token, pos_embed_interpolated), dim=1) + + return pos_embed_interpolated + + return pos_embed_interpolated + + def load_backbone(self, path: str, freeze: bool = False): + """Loads pretrained weights and handles positional embedding resizing if necessary.""" + # Load the pretrained state dict + state_dict = torch.load(path) + + # Expected shape for positional embeddings based on current model image size + + image_size = ( + self.image_size + if isinstance(self.image_size, Tuple) + else (self.image_size, self.image_size) + ) + + expected_pos_embed_shape = ( + 1, + (image_size[0] // self.patch_size) * (image_size[1] // self.patch_size) + 1, + self.hidden_dim, + ) + + # Check if positional embeddings need interpolation + if state_dict["encoder.pos_embedding"].shape != expected_pos_embed_shape: + # Extract the positional embeddings from the state dict + pretrained_pos_embed = state_dict["encoder.pos_embedding"] + + # Interpolate to match the current image size + print("Interpolating positional embeddings to match the new image size.") + with torch.no_grad(): + pos_embed_interpolated = self.interpolate_pos_embeddings( + pretrained_pos_embed, (image_size[0], image_size[1]) + ) + state_dict["encoder.pos_embedding"] = pos_embed_interpolated + + # Load the (potentially modified) state dict into the encoder + self.encoder.load_state_dict(state_dict, strict=False) + + # Optionally freeze parameters + if freeze: + for param in self.encoder.parameters(): + param.requires_grad = False + def forward(self, x: torch.Tensor): """Forward pass of the Vision Transformer Backbone. @@ -293,6 +387,7 @@ def forward(self, x: torch.Tensor): Returns: torch.Tensor: The output tensor. 
""" + # Reshape and permute the input tensor x, n_h, n_w = self._process_input(x) n = x.shape[0] @@ -328,6 +423,45 @@ def forward(self, x: torch.Tensor): return x + def load_weights(self, weights_path: str, freeze: bool = False): + + state_dict = torch.load(weights_path) + + # Get expected positional embedding shape based on current image size + + image_size = ( + self.image_size + if isinstance(self.image_size, Tuple) + else (self.image_size, self.image_size) + ) + + expected_pos_embed_shape = ( + 1, + (image_size[0] // self.patch_size) * (image_size[1] // self.patch_size) + 1, + self.hidden_dim, + ) + + # Check if positional embeddings need interpolation + if state_dict["encoder.pos_embedding"].shape != expected_pos_embed_shape: + # Extract the positional embeddings from the state dict + pretrained_pos_embed = state_dict["encoder.pos_embedding"] + + # Interpolate to match the current image size + print("Interpolating positional embeddings to match the new image size.") + with torch.no_grad(): + pos_embed_interpolated = self.interpolate_pos_embeddings( + pretrained_pos_embed, (image_size[0], image_size[1]) + ) + state_dict["encoder.pos_embedding"] = pos_embed_interpolated + + # Load the (potentially modified) state dict + self.load_state_dict(state_dict, strict=False) + + # Optionally freeze parameters + if freeze: + for param in self.parameters(): + param.requires_grad = False + class MaskedAutoencoderViT(L.LightningModule): """ diff --git a/minerva/pipelines/base.py b/minerva/pipelines/base.py index f626a3f..88167d7 100644 --- a/minerva/pipelines/base.py +++ b/minerva/pipelines/base.py @@ -159,11 +159,11 @@ def pipeline_info(self) -> Dict[str, str]: The dictionary with the pipeline information """ return { - "class_name": self.__class__.__name__, - "created_time": self._created_at, + "class_name": str(self.__class__.__name__), + "created_time": str(self._created_at), "pipeline_id": self.pipeline_id, "log_dir": str(self.log_dir), - "run_count": self._run_count, + "run_count": str(self._run_count), } @property diff --git a/minerva/pipelines/hyperopt_hyperparameter_search.py b/minerva/pipelines/hyperopt_hyperparameter_search.py new file mode 100644 index 0000000..0cafb0c --- /dev/null +++ b/minerva/pipelines/hyperopt_hyperparameter_search.py @@ -0,0 +1,154 @@ +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + +import lightning.pytorch as L +from lightning.pytorch.strategies import Strategy +from ray import tune +from ray.train import CheckpointConfig, RunConfig, ScalingConfig +from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment, prepare_trainer +from ray.train.torch import TorchTrainer +from ray.tune.schedulers import ASHAScheduler, TrialScheduler +from ray.tune.search import ConcurrencyLimiter +from ray.tune.search.hyperopt import HyperOptSearch +from ray.tune.stopper import TrialPlateauStopper + +from minerva.callbacks.HyperSearchCallbacks import TrainerReportOnIntervalCallback +from minerva.pipelines.base import Pipeline +from minerva.utils.typing import PathLike + + +class HyperoptHyperParameterSearch(Pipeline): + + def __init__( + self, + model: type, + search_space: Dict[str, Any], + log_dir: Optional[PathLike] = None, + save_run_status: bool = True, + ): + super().__init__(log_dir=log_dir, save_run_status=save_run_status) + self.model = model + self.search_space = search_space + + def _search( + self, + data: L.LightningDataModule, + ckpt_path: Optional[PathLike], + devices: Optional[str] = "auto", + accelerator: Optional[str] = "auto", 
+ strategy: Optional[Strategy] = None, + callbacks: Optional[Any] = None, + plugins: Optional[Any] = None, + num_nodes: int = 1, + debug_mode: Optional[bool] = False, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + tuner_metric: Optional[str] = "val_loss", + tuner_mode: Optional[str] = "min", + num_samples: Optional[int] = -1, + scheduler: Optional[TrialScheduler] = None, + max_concurrent: Optional[int] = 4, + initial_parameters: Optional[Dict[str, Any]] = None, + max_epochs: Optional[int] = None, + num_results: Optional[int] = 5, + std: Optional[float] = 0.01, + grace_period: Optional[int] = 50, + ) -> Any: + + def _tuner_train_func(config): + dm = deepcopy(data) + model = self.model.create_from_dict(config) + trainer = L.Trainer( + max_epochs=max_epochs or 500, + devices=devices or "auto", + accelerator=accelerator or "auto", + strategy=strategy or RayDDPStrategy(find_unused_parameters=True), + callbacks=callbacks or [TrainerReportOnIntervalCallback(500)], + plugins=plugins or [RayLightningEnvironment()], + enable_progress_bar=False, + num_nodes=num_nodes, + enable_checkpointing=False if debug_mode else None, + ) + trainer = prepare_trainer(trainer) + trainer.fit(model, dm, ckpt_path=ckpt_path) + + scheduler = scheduler or ASHAScheduler( + time_attr="training_iteration", + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + max_t=500, + grace_period=100, + ) + + scaling_config = scaling_config or ScalingConfig( + num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1} + ) + + run_config = run_config or RunConfig( + checkpoint_config=CheckpointConfig( + num_to_keep=1, + checkpoint_score_attribute="val_loss", + checkpoint_score_order="min", + ), + stop=TrialPlateauStopper( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_results=num_results or 5, + std=std or 0.01, + grace_period=grace_period or 50, + ), + ) + + ray_trainer = TorchTrainer( + _tuner_train_func, + scaling_config=scaling_config, + run_config=run_config, + ) + + algo = ConcurrencyLimiter( + HyperOptSearch(initial_parameters), max_concurrent=max_concurrent or 4 + ) + + tuner = tune.Tuner( + ray_trainer, + param_space={"train_loop_config": self.search_space}, + tune_config=tune.TuneConfig( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_samples=num_samples or -1, + search_alg=algo, + ), + ) + return tuner.fit() + + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]) -> Any: + # TODO fix this + return self.trainer.test(self.model, data, ckpt_path=ckpt_path) + + def _run( + self, + data: L.LightningDataModule, + task: Optional[Literal["search", "test", "predict"]], + ckpt_path: Optional[PathLike] = None, + config: Dict[str, Any] = {}, + **kwargs, + ) -> Any: + if task == "search": + return self._search(data, ckpt_path, **config) + elif task == "test": + return self._test(data, ckpt_path) + elif task is None: + search = self._search(data, ckpt_path, **config) + test = self._test(data, ckpt_path) + return search, test + + +def main(): + from jsonargparse import CLI + + print("Hyper Searching 🔍") + CLI(HyperoptHyperParameterSearch, as_positional=False) + + +if __name__ == "__main__": + main() diff --git a/minerva/pipelines/ray_hyperparameter_search.py b/minerva/pipelines/ray_hyperparameter_search.py new file mode 100644 index 0000000..c8b96c2 --- /dev/null +++ b/minerva/pipelines/ray_hyperparameter_search.py @@ -0,0 +1,133 @@ +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + 
+import lightning.pytorch as L +from lightning.pytorch.strategies import Strategy +from ray import tune +from ray.train import CheckpointConfig, RunConfig, ScalingConfig +from ray.train.lightning import RayDDPStrategy, RayLightningEnvironment, prepare_trainer +from ray.train.torch import TorchTrainer +from ray.tune.schedulers import ASHAScheduler, TrialScheduler + +from minerva.callbacks.HyperSearchCallbacks import TrainerReportKeepOnlyLastCallback +from minerva.pipelines.base import Pipeline +from minerva.utils.typing import PathLike + + +class RayHyperParameterSearch(Pipeline): + + def __init__( + self, + model: type, + search_space: Dict[str, Any], + log_dir: Optional[PathLike] = None, + save_run_status: bool = True, + ): + super().__init__(log_dir=log_dir, save_run_status=save_run_status) + self.model = model + self.search_space = search_space + + def _search( + self, + data: L.LightningDataModule, + ckpt_path: Optional[PathLike], + devices: Optional[str] = "auto", + accelerator: Optional[str] = "auto", + strategy: Optional[Strategy] = None, + callbacks: Optional[Any] = None, + plugins: Optional[Any] = None, + num_nodes: int = 1, + debug_mode: Optional[bool] = False, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + tuner_metric: Optional[str] = "val_loss", + tuner_mode: Optional[str] = "min", + num_samples: Optional[int] = 10, + scheduler: Optional[TrialScheduler] = None, + ) -> Any: + + def _tuner_train_func(config): + dm = deepcopy(data) + model = self.model.create_from_dict(config_dict=config) + trainer = L.Trainer( + devices=devices or "auto", + accelerator=accelerator or "auto", + strategy=strategy or RayDDPStrategy(find_unused_parameters=True), + callbacks=callbacks or [TrainerReportKeepOnlyLastCallback()], + plugins=plugins or [RayLightningEnvironment()], + enable_progress_bar=False, + num_nodes=num_nodes, + enable_checkpointing=False if debug_mode else None, + ) + trainer = prepare_trainer(trainer) + trainer.fit(model, dm, ckpt_path=ckpt_path) + + scheduler = scheduler or ASHAScheduler( + time_attr="training_iteration", + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + max_t=2, + grace_period=1, + brackets=1, + ) + + scaling_config = scaling_config or ScalingConfig( + num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1} + ) + + run_config = run_config or RunConfig( + checkpoint_config=CheckpointConfig( + num_to_keep=1, + checkpoint_score_attribute="val_loss", + checkpoint_score_order="min", + checkpoint_frequency=10, + ) + ) + + ray_trainer = TorchTrainer( + _tuner_train_func, + scaling_config=scaling_config, + run_config=run_config, + ) + tuner = tune.Tuner( + ray_trainer, + param_space={"train_loop_config": self.search_space}, + tune_config=tune.TuneConfig( + metric=tuner_metric or "val_loss", + mode=tuner_mode or "min", + num_samples=num_samples or 10, + scheduler=scheduler, + ), + ) + return tuner.fit() + + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]) -> Any: + # TODO fix this + return self.trainer.test(self.model, data, ckpt_path=ckpt_path) + + def _run( + self, + data: L.LightningDataModule, + task: Optional[Literal["search", "test", "predict"]], + ckpt_path: Optional[PathLike] = None, + **kwargs, + ) -> Any: + if task == "search": + return self._search(data, ckpt_path, **kwargs) + elif task == "test": + return self._test(data, ckpt_path) + elif task is None: + search = self._search(data, ckpt_path, **kwargs) + test = self._test(data, ckpt_path) + return search, test + + 
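
A hypothetical driver for the new RayHyperParameterSearch pipeline, to show how the pieces above fit together. The search-space keys, the data module, and the public `run()` entry point inherited from the Pipeline base class are assumptions for illustration; an actual run requires a Ray environment.

```python
# Hypothetical usage sketch, not part of the patch. The model argument is a
# class (not an instance) that must expose a `create_from_dict` factory, as
# SETR_PUP does; search-space keys below are placeholders.
from ray import tune

from minerva.models.nets.image.setr import SETR_PUP
from minerva.pipelines.ray_hyperparameter_search import RayHyperParameterSearch

search_space = {
    "learning_rate": tune.loguniform(1e-5, 1e-2),
    "num_classes": 6,
}

pipeline = RayHyperParameterSearch(
    model=SETR_PUP,
    search_space=search_space,
    log_dir="logs/ray_search",
)

# `data_module` is any LightningDataModule; assuming the Pipeline base class
# forwards its public entry point to `_run`, a search would be launched with:
# results = pipeline.run(data=data_module, task="search")
```
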
+def main(): + from jsonargparse import CLI + + print("Hyper Searching 🔍") + CLI(RayHyperParameterSearch, as_positional=False) + + +if __name__ == "__main__": + main() diff --git a/minerva/transforms/context_transform.py b/minerva/transforms/context_transform.py new file mode 100644 index 0000000..8c09564 --- /dev/null +++ b/minerva/transforms/context_transform.py @@ -0,0 +1,70 @@ +from typing import Any + +import numpy as np + +from minerva.transforms.transform import _Transform + + +class ClassRatioCrop(_Transform): + + def __init__( + self, + target_h_size: int, + target_w_size: int, + cat_max_ratio: float = 0.75, + max_attempts: int = 10, + ) -> None: + """Crop the input data to a target size, while keeping the ratio of classes in the image. + + Parameters + ---------- + target_h_size : int + The target height of the crop. + target_w_size : int + The target width of the crop. + cat_max_ratio : float, optional + The maximum ratio of pixels of a single class in the crop, by default 0.75 + max_attempts : int, optional + The maximum number of attempts to crop the image, by default 10 + """ + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.cat_max_ratio = cat_max_ratio + self.max_attempts = max_attempts + self.crop_coords = None + + def __call__(self, x: np.ndarray) -> np.ndarray: + h, w = x.shape[:2] + + if self.crop_coords is None: + if not issubclass(x.dtype.type, np.integer): + raise ValueError( + "You must provide a mask first to use this functionality. For that you enable support_context_transforms if your dataset supports it, or use a different dataset that does supports it." + ) + + for _ in range(self.max_attempts): + # Randomly select the top-left corner for the crop + top = np.random.randint(0, h - self.target_h_size + 1) + left = np.random.randint(0, w - self.target_w_size + 1) + + # Extract the crop from both image and label + cropped_image = x[ + top : top + self.target_h_size, left : left + self.target_w_size + ] + + # Calculate the proportion of the most frequent class in the crop + _, counts = np.unique(cropped_image, return_counts=True) + class_ratios = counts / (self.target_h_size * self.target_w_size) + + if np.max(class_ratios) <= self.cat_max_ratio: + self.crop_coords = (top, left) + return cropped_image + + # If no valid crop was found, return the last crop (without meeting the ratio constraint) + self.crop_coords = (top, left) + return cropped_image + + else: + top, left = self.crop_coords + self.crop_coords = None + return x[top : top + self.target_h_size, left : left + self.target_w_size] diff --git a/minerva/transforms/random_transform.py b/minerva/transforms/random_transform.py new file mode 100644 index 0000000..6107845 --- /dev/null +++ b/minerva/transforms/random_transform.py @@ -0,0 +1,120 @@ +import random +from typing import List, Optional, Tuple, Union + +import numpy as np + +from minerva.transforms.transform import Flip, Resize, _Transform + + +class EmptyTransform(_Transform): + """A transform that does nothing to the input data.""" + + def __call__(self, data): + return data + + +class _RandomSyncedTransform(_Transform): + """Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them.""" + + def __init__(self, num_samples: int, seed: Optional[int] = None): + """Orchestrate the application of a type of random transform to a list of data, ensuring that the same random state is used for all of them. 
+
+        Parameters
+        ----------
+        num_samples : int
+            The number of samples that will be transformed.
+        seed : Optional[int], optional
+            The seed that will be used to generate the random state, by default None.
+        """
+        self.num_samples = num_samples
+        self.transformations_executed = 0
+        self.rng = np.random.default_rng(seed)
+        self.transform = EmptyTransform()
+
+    def __call__(self, data):
+        if self.transformations_executed == 0:
+            self.transform = self.select_transform(data)
+        # Advance the counter, wrapping around so that a new random transform is
+        # drawn after `num_samples` consecutive calls (e.g. an image/label pair).
+        self.transformations_executed = (
+            self.transformations_executed + 1
+        ) % self.num_samples
+        return self.transform(data)
+
+    def select_transform(self, data) -> _Transform:
+        raise NotImplementedError(
+            "This method should be implemented by the child class."
+        )
+
+
+class RandomFlip(_RandomSyncedTransform):
+
+    def __init__(
+        self,
+        num_samples: int,
+        possible_axis: Union[int, List[int]] = 0,
+        seed: Optional[int] = None,
+    ):
+        """A transform that flips the input data along one or more randomly chosen axes.
+
+        Parameters
+        ----------
+        num_samples : int
+            The number of samples that will be transformed.
+        possible_axis : Union[int, List[int]], optional
+            Axis or axes that may be flipped; each is chosen at random, by default 0
+        seed : Optional[int], optional
+            A seed to ensure a deterministic run, by default None
+        """
+        super().__init__(num_samples, seed)
+        self.possible_axis = possible_axis
+
+    def select_transform(self, data):
+        """Selects the transform to be applied to the data."""
+
+        if isinstance(self.possible_axis, int):
+            flip_axis = self.rng.choice([True, False])
+            if flip_axis:
+                return Flip(axis=self.possible_axis)
+
+        else:
+            flip_axis = [
+                bool(self.rng.choice([True, False]))
+                for _ in range(len(self.possible_axis))
+            ]
+            if True in flip_axis:
+                chosen_axis = [
+                    axis for axis, flip in zip(self.possible_axis, flip_axis) if flip
+                ]
+                return Flip(axis=chosen_axis)
+
+        return EmptyTransform()
+
+
+class RandomResize(_RandomSyncedTransform):
+
+    def __init__(
+        self,
+        target_scale: Tuple[int, int],
+        ratio_range: Tuple[float, float],
+        num_samples: int,
+        seed: Optional[int] = None,
+    ):
+        super().__init__(num_samples, seed)
+        self.target_scale = target_scale
+        self.ratio_range = ratio_range
+        self.resize: Optional[_Transform] = None
+
+    def select_transform(self, data):
+        # Apply a random scaling factor within the ratio range to the (h, w) target scale.
+        scale_factor = self.rng.uniform(*self.ratio_range)
+        new_height = int(self.target_scale[0] * scale_factor)
+        new_width = int(self.target_scale[1] * scale_factor)
+
+        # Resize takes (target_h_size, target_w_size), so the height is passed first.
+        return Resize(new_height, new_width)
diff --git a/minerva/transforms/transform.py b/minerva/transforms/transform.py
index 9d2dd28..48744ad 100644
--- a/minerva/transforms/transform.py
+++ b/minerva/transforms/transform.py
@@ -1,6 +1,7 @@
 from itertools import product
-from typing import Any, List, Sequence, Union
+from typing import Any, List, Literal, Sequence, Tuple, Union
 
+import cv2
 import numpy as np
 import torch
 from perlin_noise import PerlinNoise
@@ -69,16 +70,16 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
         """
         if isinstance(self.axis, int):
-            return np.flip(x, axis=self.axis)
+            return np.flip(x, axis=self.axis).copy()
         assert (
             len(self.axis) <= x.ndim
-        ), "Axis list has more dimentions than input data. The lenth of axis needs to be less or equal to input dimentions."
+ ), "Axis list has more dimensions than input data. The length of axis needs to be less or equal to input dimensions." for axis in self.axis: x = np.flip(x, axis=axis) - return x + return x.copy() class PerlinMasker(_Transform): @@ -175,19 +176,193 @@ def __call__(self, x: np.ndarray) -> np.ndarray: class Padding(_Transform): - def __init__(self, target_h_size: int, target_w_size: int): + def __init__( + self, + target_h_size: int, + target_w_size: int, + padding_mode: Literal["reflect", "constant"] = "reflect", + constant_value: int = 0, + mask_value: int = 255, + ): self.target_h_size = target_h_size self.target_w_size = target_w_size + self.padding_mode = padding_mode + self.constant_value = constant_value + self.mask_value = mask_value def __call__(self, x: np.ndarray) -> np.ndarray: h, w = x.shape[:2] pad_h = max(0, self.target_h_size - h) pad_w = max(0, self.target_w_size - w) + is_label = True if x.dtype == np.uint8 else False + if len(x.shape) == 2: - padded = np.pad(x, ((0, pad_h), (0, pad_w)), mode="reflect") + if self.padding_mode == "reflect": + padded = np.pad(x, ((0, pad_h), (0, pad_w)), mode="reflect") + elif self.padding_mode == "constant": + if is_label: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w)), + mode="constant", + constant_values=self.mask_value, + ) + else: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w)), + mode="constant", + constant_values=self.constant_value, + ) padded = np.expand_dims(padded, axis=2) - else: - padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") - padded = np.transpose(padded, (2, 0, 1)) + else: + if self.padding_mode == "reflect": + padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") + elif self.padding_mode == "constant": + if is_label: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=self.mask_value, + ) + else: + padded = np.pad( + x, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=self.constant_value, + ) return padded + + +class Normalize(_Transform): + def __init__(self, mean, std, to_rgb=False, normalize_labels=False): + """ + Initialize the Normalize transform. + + Args: + means (list or tuple): A list or tuple containing the mean for each channel. + stds (list or tuple): A list or tuple containing the standard deviation for each channel. + to_rgb (bool): If True, convert the data from BGR to RGB. + """ + assert len(mean) == len( + std + ), "Means and standard deviations must have the same length." + self.mean = mean + self.std = std + self.to_rgb = to_rgb + self.normalize_labels = normalize_labels + + def __call__(self, data): + """ + Normalize the input data using the provided means and standard deviations. + + Args: + data (numpy.ndarray): Input data array of shape (C, H, W) where C is the number of channels. + + Returns: + numpy.ndarray: Normalized data. + """ + + is_label = True if data.dtype == np.uint8 else False + + if is_label and self.normalize_labels: + # Convert from gray scale (1 channel) to RGB (3 channels) if to_rgb is True + if self.to_rgb and data.shape[0] == 1: + data = np.repeat(data, 3, axis=0) + + assert data.shape[0] == len( + self.mean + ), f"Number of channels in data does not match the number of provided mean/std. 
{data.shape}" + + # Normalize each channel + for i in range(len(self.mean)): + data[i, :, :] = (data[i, :, :] - self.mean[i]) / self.std[i] + + return data + + +class Crop(_Transform): + def __init__( + self, + target_h_size: int, + target_w_size: int, + start_coord: Tuple[int, int] = (0, 0), + ): + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.start_coord = start_coord + + def __call__(self, x: np.ndarray) -> np.ndarray: + h, w = x.shape[:2] + start_h = (h - self.target_h_size) // 2 + start_w = (w - self.target_w_size) // 2 + end_h = start_h + self.target_h_size + end_w = start_w + self.target_w_size + if len(x.shape) == 2: + cropped = x[start_h:end_h, start_w:end_w] + cropped = np.expand_dims(cropped, axis=2) + else: + cropped = x[start_h:end_h, start_w:end_w] + + return cropped + + +class Transpose(_Transform): + """Reorder the axes of numpy arrays.""" + + def __init__(self, axes: Sequence[int]): + """Reorder the axes of numpy arrays. + + Parameters + ---------- + axes : int + The order of the new axes + """ + self.axes = axes + + def __call__(self, x: np.ndarray) -> np.ndarray: + """Reorder the axes of numpy arrays.""" + + if len(x.shape) == 2: + x = np.expand_dims(x, axis=2) + return np.transpose(x, self.axes) + + +class Resize(_Transform): + + def __init__( + self, + target_h_size: int, + target_w_size: int, + keep_aspect_ratio: bool = False, + ): + self.target_h_size = target_h_size + self.target_w_size = target_w_size + self.keep_aspect_ratio = keep_aspect_ratio + + def __call__(self, x: np.ndarray) -> np.ndarray: + original_height, original_width = x.shape[:2] + + if not self.keep_aspect_ratio: + # Direct resize without keeping the aspect ratio + return cv2.resize( + x, + (self.target_w_size, self.target_h_size), + interpolation=cv2.INTER_NEAREST, + ) + + # Calculate scaling factors for both dimensions + width_scale = self.target_w_size / original_width + height_scale = self.target_h_size / original_height + + # Choose the smaller scale to keep aspect ratio, and round down + scale = min(width_scale, height_scale) + + # Compute new dimensions, rounding down to match MMsegmentation's behavior + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + return cv2.resize(x, (new_width, new_height), interpolation=cv2.INTER_NEAREST) diff --git a/minerva/utils/position_embedding.py b/minerva/utils/position_embedding.py index 0be7959..c0be963 100644 --- a/minerva/utils/position_embedding.py +++ b/minerva/utils/position_embedding.py @@ -1,7 +1,7 @@ from functools import partial +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn as nn from timm.models.vision_transformer import Block, PatchEmbed diff --git a/pyproject.toml b/pyproject.toml index 4248306..60817f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,10 @@ version = "0.2.2-beta" dependencies = [ "gitpython", + "jsonargparse", + "ray[tune]", "jsonargparse>=4.27", - "lightning>=2.1.9", + "lightning==2.2.0", "numpy>=1.23.5", "pandas>=2.2.2", "perlin-noise>=1.12", @@ -45,13 +47,15 @@ dependencies = [ "tifffile>=2024", "timm>=0.9", "torch>=2.0.8", - "torchmetrics>=1.3.0", + "torchmetrics==1.3.1", "torchvision>=0.15", - "zarr>=2.17" + "zarr>=2.17", + "hyperopt>=0.2.5", ] -[tool.setuptools] -packages = ["minerva"] +[tool.setuptools.packages.find] +where = ["."] +include = ["minerva*"] [project.optional-dependencies] dev = ["mock", "pytest", "black", "isort"] @@ -63,7 +67,7 @@ docs = [ "sphinx-rtd-theme", 
"sphinx-autodoc-typehints", "sphinx-argparse", - "sphinx-autoapi" + "sphinx-autoapi", ] [project.urls] diff --git a/tests/transforms/test_random_flip.py b/tests/transforms/test_random_flip.py new file mode 100644 index 0000000..51d10c7 --- /dev/null +++ b/tests/transforms/test_random_flip.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest + +from minerva.transforms.random_transform import RandomFlip + + +def test_random_flip_single_axis_with_flip(): + # Create a dummy input + x = np.random.rand(10, 20) + + # Apply the flip transform along the first axis + flip_transform = RandomFlip(possible_axis=0, num_samples=1, seed=0) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # Check if the flipped data is different from the input + assert np.allclose(flipped_x, np.flip(x, axis=0)) + + +def test_random_flip_single_axis_without_flip(): + # Create a dummy input + x = np.random.rand(10, 20) + + # Apply the flip transform along the first axis + flip_transform = RandomFlip(possible_axis=0, num_samples=1, seed=1) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # Check if the flipped data is different from the input + assert np.allclose(flipped_x, x) + + +def test_random_flip_first_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = RandomFlip(possible_axis=[0, 1], num_samples=1, seed=0) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # check if only the first axis is flipped + assert np.allclose(flipped_x, np.flip(x, axis=0)) + + +def test_random_flip_second_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = RandomFlip(possible_axis=[0, 1], num_samples=1, seed=1) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # check if the second axis is flipped + assert np.allclose(flipped_x, np.flip(x, axis=1)) + + +def test_random_flip_two_axis(): + # Create a dummy input + x = np.random.rand(10, 20, 30) + + # Apply the flip transform along multiple axes + flip_transform = RandomFlip(possible_axis=[0, 1], num_samples=1, seed=2) + flipped_x = flip_transform(x) + + # Check if the flipped data has the same shape as the input + assert flipped_x.shape == x.shape + + # check if both axis are flipped + assert np.allclose(flipped_x, np.flip(x, axis=(0, 1)))