Skip to content

Commit

Permalink
Merge pull request #25 from GabrielBG0/14-supervised-dataset
Browse files Browse the repository at this point in the history
Supervised Reconstruction Dataset + bug fixes and formatting
  • Loading branch information
GabrielBG0 authored Feb 13, 2024
2 parents 945b71d + ffb620f commit 1aa1f34
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 26 deletions.
90 changes: 90 additions & 0 deletions sslt/data/datasets/supervised_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import List, Tuple

import numpy as np
from base import SimpleDataset

from sslt.data.readers.reader import _Reader
from sslt.transforms.transform import _Transform


class SupervisedReconstructionDataset(SimpleDataset):
"""A simple dataset class for supervised reconstruction tasks.
In summary, each element of the dataset is a pair of data, where the first
element is the input data and the second element is the target data.
Usually, both input and target data have the same shape.
This dataset is useful for supervised tasks such as image reconstruction,
segmantic segmentation, and object detection, where the input data is the
original data and the target is a mask or a segmentation map.
Examples
--------
1. Semantic Segmentation Dataset:
```python
from sslt.data.readers import ImageReader
from sslt.transforms import ImageTransform
from sslt.data.datasets import SupervisedReconstructionDataset
# Create the readers
image_reader = ImageReader("path/to/images")
mask_reader = ImageReader("path/to/masks")
# Create the transforms
image_transform = ImageTransform()
# Create the dataset
dataset = SupervisedReconstructionDataset(
readers=[image_reader, mask_reader],
transforms=image_transform
)
# Load the first sample
dataset[0] # Returns a tuple: (image, mask)
```
"""

def __init__(
self, readers: List[_Reader], transforms: _Transform | None = None
):
"""A simple dataset class for supervised reconstruction tasks.
Parameters
----------
readers: List[_Reader]
List of data readers. It must contain exactly 2 readers.
The first reader for the input data and the second reader for the
target data.
transforms: _Transform | None
Optional data transformation pipeline.
Raises
-------
AssertionError: If the number of readers is not exactly 2.
"""
super().__init__(readers, transforms)

assert (
len(self.readers) == 2
), "SupervisedReconstructionDataset requires exactly 2 readers"

def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
"""Load data from sources and apply specified transforms. The same
transform is applied to both input and target data.
Parameters
----------
index : int
The index of the sample to load.
Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing two numpy arrays representing the data.
"""
data = super().__getitem__(index)

return (data[0], data[1])
17 changes: 9 additions & 8 deletions sslt/data/readers/png_reader.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
from pathlib import Path
from typing import Union
from reader import _Reader

import numpy as np
from PIL import Image
from reader import _Reader


class PNGReader(_Reader):
"""This class loads a PNG file from a directory. It assumes that the PNG
files are named with a number as the filename, starting from 0. This is
shown below.
"""This class loads a PNG file from a directory. It assumes that the PNG
files are named with a number as the filename, starting from 0. This is
shown below.
```
/path/
├── 0.png
├── 1.png
├── 2.png
└── ...
```
Thus, the element at index `i` will be the file `i.png`.
"""

Expand All @@ -35,7 +36,7 @@ def __init__(self, path: Union[Path, str]):
self.len = len(list(self.path.glob("*.png")))

def __getitem__(self, index: int) -> np.ndarray:
"""Retrieve the PNG file at the specified index. The index will be
"""Retrieve the PNG file at the specified index. The index will be
used as the filename of the PNG file.
Parameters
Expand Down Expand Up @@ -66,4 +67,4 @@ def __len__(self) -> int:
int
The number of PNG files in the directory.
"""
return len(self.files)
return self.len
7 changes: 4 additions & 3 deletions sslt/data/readers/reader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Any

import numpy as np


class _Reader:
"""
Base class for readers. Readers define an ordered collection of data and
Base class for readers. Readers define an ordered collection of data and
provide methods to access it. This class primarily handles:
1. Definition of data structure and storage.
2. Reading data from the source.
The access is handled by the __getitem__ and __len__ methods, which should be
implemented by a subclass. Readers usually returns a single item at a time,
that can be a single image, a single label, etc.
Expand Down
21 changes: 12 additions & 9 deletions sslt/data/readers/tiff_reader.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
from pathlib import Path
from typing import Union
from sslt.data.readers.reader import _Reader

import numpy as np
import tifffile as tiff
from pathlib import Path

from sslt.data.readers.reader import _Reader


class TiffReader(_Reader):
"""This class loads a TIFF file from a directory. It assumes that the TIFF
files are named with a number as the filename, starting from 0. This is
shown below.
"""This class loads a TIFF file from a directory. It assumes that the TIFF
files are named with a number as the filename, starting from 0. This is
shown below.
```
/path/
├── 0.tiff
├── 1.tiff
├── 2.tiff
└── ...
```
Thus, the element at index `i` will be the file `i.tiff`.
"""

def __init__(self, path: str):
self.path = Path(path)
if not self.path.is_dir():
raise ValueError(f"Path {path} is not a directory")
self.len = len(list(self.path.glob("*.tif")))

def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
"""Retrieve the TIFF file at the specified index. The index will be
"""Retrieve the TIFF file at the specified index. The index will be
used as the filename of the TIFF file.
Parameters
Expand Down Expand Up @@ -58,4 +61,4 @@ def __len__(self) -> int:
int
The number of TIFF files in the directory.
"""
return len(self.files)
return self.len
12 changes: 6 additions & 6 deletions sslt/transforms/transform.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@

from typing import Any, List


class _Transform:
"""This class is a base class for all transforms. Transforms is just a
fancy word for a function that takes an input and returns an output. The
input and output can be anything. However, transforms operates over a
"""This class is a base class for all transforms. Transforms is just a
fancy word for a function that takes an input and returns an output. The
input and output can be anything. However, transforms operates over a
single sample of data and does not require any additional information to
perform the transformation. The __call__ method should be overridden in
subclasses to define the transformation logic.
"""

def __call__(self, *args, **kwargs) -> Any:
"""Implement the transformation logic in this method. Usually, the
"""Implement the transformation logic in this method. Usually, the
transformation is applyied on a single sample of data.
"""
raise NotImplementedError


class TransformPipeline(_Transform):
"""Apply a sequence of transforms to a single sample of data and return the
"""Apply a sequence of transforms to a single sample of data and return the
transformed data.
"""

def __init__(self, transforms: List[_Transform]):
"""Apply a sequence of transforms to a single sample of data and return
the transformed data.
Expand Down

0 comments on commit 1aa1f34

Please sign in to comment.