diff --git a/sslt/data/datasets/supervised_dataset.py b/sslt/data/datasets/supervised_dataset.py new file mode 100644 index 0000000..cbe478b --- /dev/null +++ b/sslt/data/datasets/supervised_dataset.py @@ -0,0 +1,90 @@ +from typing import List, Tuple + +import numpy as np +from base import SimpleDataset + +from sslt.data.readers.reader import _Reader +from sslt.transforms.transform import _Transform + + +class SupervisedReconstructionDataset(SimpleDataset): + """A simple dataset class for supervised reconstruction tasks. + + In summary, each element of the dataset is a pair of data, where the first + element is the input data and the second element is the target data. + Usually, both input and target data have the same shape. + + This dataset is useful for supervised tasks such as image reconstruction, + segmantic segmentation, and object detection, where the input data is the + original data and the target is a mask or a segmentation map. + + Examples + -------- + + 1. Semantic Segmentation Dataset: + + ```python + from sslt.data.readers import ImageReader + from sslt.transforms import ImageTransform + from sslt.data.datasets import SupervisedReconstructionDataset + + # Create the readers + image_reader = ImageReader("path/to/images") + mask_reader = ImageReader("path/to/masks") + + # Create the transforms + image_transform = ImageTransform() + + # Create the dataset + dataset = SupervisedReconstructionDataset( + readers=[image_reader, mask_reader], + transforms=image_transform + ) + + # Load the first sample + dataset[0] # Returns a tuple: (image, mask) + ``` + """ + + def __init__( + self, readers: List[_Reader], transforms: _Transform | None = None + ): + """A simple dataset class for supervised reconstruction tasks. + + Parameters + ---------- + readers: List[_Reader] + List of data readers. It must contain exactly 2 readers. + The first reader for the input data and the second reader for the + target data. + transforms: _Transform | None + Optional data transformation pipeline. + + Raises + ------- + AssertionError: If the number of readers is not exactly 2. + """ + super().__init__(readers, transforms) + + assert ( + len(self.readers) == 2 + ), "SupervisedReconstructionDataset requires exactly 2 readers" + + def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]: + """Load data from sources and apply specified transforms. The same + transform is applied to both input and target data. + + Parameters + ---------- + index : int + The index of the sample to load. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + A tuple containing two numpy arrays representing the data. + + """ + data = super().__getitem__(index) + + return (data[0], data[1]) diff --git a/sslt/data/readers/png_reader.py b/sslt/data/readers/png_reader.py index 19e6182..65cb7ee 100644 --- a/sslt/data/readers/png_reader.py +++ b/sslt/data/readers/png_reader.py @@ -1,15 +1,16 @@ from pathlib import Path from typing import Union -from reader import _Reader + import numpy as np from PIL import Image +from reader import _Reader class PNGReader(_Reader): - """This class loads a PNG file from a directory. It assumes that the PNG - files are named with a number as the filename, starting from 0. This is - shown below. - + """This class loads a PNG file from a directory. It assumes that the PNG + files are named with a number as the filename, starting from 0. This is + shown below. + ``` /path/ ├── 0.png @@ -17,7 +18,7 @@ class PNGReader(_Reader): ├── 2.png └── ... ``` - + Thus, the element at index `i` will be the file `i.png`. """ @@ -35,7 +36,7 @@ def __init__(self, path: Union[Path, str]): self.len = len(list(self.path.glob("*.png"))) def __getitem__(self, index: int) -> np.ndarray: - """Retrieve the PNG file at the specified index. The index will be + """Retrieve the PNG file at the specified index. The index will be used as the filename of the PNG file. Parameters @@ -66,4 +67,4 @@ def __len__(self) -> int: int The number of PNG files in the directory. """ - return len(self.files) + return self.len diff --git a/sslt/data/readers/reader.py b/sslt/data/readers/reader.py index 5a02a5f..0ab5a5b 100644 --- a/sslt/data/readers/reader.py +++ b/sslt/data/readers/reader.py @@ -1,15 +1,16 @@ from typing import Any + import numpy as np class _Reader: """ - Base class for readers. Readers define an ordered collection of data and + Base class for readers. Readers define an ordered collection of data and provide methods to access it. This class primarily handles: - + 1. Definition of data structure and storage. 2. Reading data from the source. - + The access is handled by the __getitem__ and __len__ methods, which should be implemented by a subclass. Readers usually returns a single item at a time, that can be a single image, a single label, etc. diff --git a/sslt/data/readers/tiff_reader.py b/sslt/data/readers/tiff_reader.py index 4e2098f..cce1027 100644 --- a/sslt/data/readers/tiff_reader.py +++ b/sslt/data/readers/tiff_reader.py @@ -1,15 +1,17 @@ +from pathlib import Path from typing import Union -from sslt.data.readers.reader import _Reader + import numpy as np import tifffile as tiff -from pathlib import Path + +from sslt.data.readers.reader import _Reader class TiffReader(_Reader): - """This class loads a TIFF file from a directory. It assumes that the TIFF - files are named with a number as the filename, starting from 0. This is - shown below. - + """This class loads a TIFF file from a directory. It assumes that the TIFF + files are named with a number as the filename, starting from 0. This is + shown below. + ``` /path/ ├── 0.tiff @@ -17,9 +19,10 @@ class TiffReader(_Reader): ├── 2.tiff └── ... ``` - + Thus, the element at index `i` will be the file `i.tiff`. """ + def __init__(self, path: str): self.path = Path(path) if not self.path.is_dir(): @@ -27,7 +30,7 @@ def __init__(self, path: str): self.len = len(list(self.path.glob("*.tif"))) def __getitem__(self, index: Union[int, slice]) -> np.ndarray: - """Retrieve the TIFF file at the specified index. The index will be + """Retrieve the TIFF file at the specified index. The index will be used as the filename of the TIFF file. Parameters @@ -58,4 +61,4 @@ def __len__(self) -> int: int The number of TIFF files in the directory. """ - return len(self.files) + return self.len diff --git a/sslt/transforms/transform.py b/sslt/transforms/transform.py index 01375c3..00051a3 100644 --- a/sslt/transforms/transform.py +++ b/sslt/transforms/transform.py @@ -1,27 +1,27 @@ - from typing import Any, List class _Transform: - """This class is a base class for all transforms. Transforms is just a - fancy word for a function that takes an input and returns an output. The - input and output can be anything. However, transforms operates over a + """This class is a base class for all transforms. Transforms is just a + fancy word for a function that takes an input and returns an output. The + input and output can be anything. However, transforms operates over a single sample of data and does not require any additional information to perform the transformation. The __call__ method should be overridden in subclasses to define the transformation logic. """ def __call__(self, *args, **kwargs) -> Any: - """Implement the transformation logic in this method. Usually, the + """Implement the transformation logic in this method. Usually, the transformation is applyied on a single sample of data. """ raise NotImplementedError class TransformPipeline(_Transform): - """Apply a sequence of transforms to a single sample of data and return the + """Apply a sequence of transforms to a single sample of data and return the transformed data. """ + def __init__(self, transforms: List[_Transform]): """Apply a sequence of transforms to a single sample of data and return the transformed data.