Supervised Reconstruction Dataset + bug fixes and formatting #25

Merged (5 commits, Feb 13, 2024)
90 changes: 90 additions & 0 deletions sslt/data/datasets/supervised_dataset.py
@@ -0,0 +1,90 @@
from typing import List, Tuple

import numpy as np

from sslt.data.datasets.base import SimpleDataset
from sslt.data.readers.reader import _Reader
from sslt.transforms.transform import _Transform


class SupervisedReconstructionDataset(SimpleDataset):
"""A simple dataset class for supervised reconstruction tasks.

In summary, each element of the dataset is a pair of data, where the first
element is the input data and the second element is the target data.
Usually, both input and target data have the same shape.

This dataset is useful for supervised tasks such as image reconstruction,
semantic segmentation, and object detection, where the input data is the
original data and the target is a mask or a segmentation map.

Examples
--------

1. Semantic Segmentation Dataset:

```python
from sslt.data.readers import ImageReader
from sslt.transforms import ImageTransform
from sslt.data.datasets import SupervisedReconstructionDataset

# Create the readers
image_reader = ImageReader("path/to/images")
mask_reader = ImageReader("path/to/masks")

# Create the transforms
image_transform = ImageTransform()

# Create the dataset
dataset = SupervisedReconstructionDataset(
readers=[image_reader, mask_reader],
transforms=image_transform
)

# Load the first sample
dataset[0] # Returns a tuple: (image, mask)
```
"""

def __init__(
self, readers: List[_Reader], transforms: _Transform | None = None
):
"""A simple dataset class for supervised reconstruction tasks.

Parameters
----------
readers: List[_Reader]
List of data readers. It must contain exactly 2 readers.
The first reader for the input data and the second reader for the
target data.
transforms: _Transform | None
Optional data transformation pipeline.

Raises
------
AssertionError
    If the number of readers is not exactly 2.
"""
super().__init__(readers, transforms)

assert (
len(self.readers) == 2
), "SupervisedReconstructionDataset requires exactly 2 readers"

def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
Collaborator:
The SimpleDataset already implements the equivalent __getitem__ code to this one.
Thus, you can omit this __getitem__ implementation.
It is worth noticing that we cannot infer the return type (inside the tuple) yet, as it depends on the reader. The second element could be an int instead of np.ndarray, representing the label, for instance.

Collaborator (Author):

The goal with this dataset is to be as specific as possible, hence the transform pipeline and the (exactly) two readers. The implementation is almost the same, but it uses only one transform for both data points and returns a known type as its output (the numpy array tuple), which is better for downstream code in the full pipeline. Sure, as implemented, I can't be sure what types the readers return, but if that's a problem I would rather add a check to ensure it than return Any.

Collaborator:

Hello @GabrielBG0. I did not see the class name, my bad.

  1. I agree with a more specific implementation. However, I think this class could be a bit more generic. The name suggests that this dataset would only be used for semantic segmentation, which is not true. In fact, the same implementation would work for any other reconstruction task (predicting seismic attributes, seismic facies classification/segmentation, etc.). Thus, any task that takes an input and has a target with the same shape as the input should subclass this one. Therefore, maybe we can change its name to something like SupervisedReconstructionDataset. What do you think?
  2. I agree with the type hint, as this is a specific class. However, it is worth noticing that, yes, this is the same behavior as in the base class: if you have only one Transform and several readers, the same transform will be applied to all data fetched from the readers (equivalent to the code here). We can still rewrite this whole __getitem__ implementation as a simple super().__getitem__(index). This would reduce code duplication and unit test cases.

Collaborator (Author):

Sure, what do you think about it now?

"""Load data from sources and apply specified transforms. The same
transform is applied to both input and target data.

Parameters
----------
index : int
The index of the sample to load.

Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing two numpy arrays representing the data.

"""
data = super().__getitem__(index)

return (data[0], data[1])
17 changes: 9 additions & 8 deletions sslt/data/readers/png_reader.py
@@ -1,23 +1,24 @@
from pathlib import Path
from typing import Union

import numpy as np
from PIL import Image
from reader import _Reader


class PNGReader(_Reader):
"""This class loads a PNG file from a directory. It assumes that the PNG
files are named with a number as the filename, starting from 0. This is
shown below.
"""This class loads a PNG file from a directory. It assumes that the PNG
files are named with a number as the filename, starting from 0. This is
shown below.

```
/path/
├── 0.png
├── 1.png
├── 2.png
└── ...
```

Thus, the element at index `i` will be the file `i.png`.
"""

@@ -35,7 +36,7 @@ def __init__(self, path: Union[Path, str]):
self.len = len(list(self.path.glob("*.png")))

def __getitem__(self, index: int) -> np.ndarray:
"""Retrieve the PNG file at the specified index. The index will be
"""Retrieve the PNG file at the specified index. The index will be
used as the filename of the PNG file.

Parameters
@@ -66,4 +67,4 @@ def __len__(self) -> int:
int
The number of PNG files in the directory.
"""
-return len(self.files)
+return self.len
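The `__len__` fix above works because the file count is cached in the constructor; there is no `self.files` attribute to fall back on. A stdlib-only sketch of the same numbered-file pattern (`NumberedFileReader` is a hypothetical name for illustration; the real `PNGReader` additionally decodes the image with PIL and returns an `np.ndarray`):

```python
from pathlib import Path
from typing import Union


class NumberedFileReader:
    """Files are named 0.png, 1.png, ..., so index i maps to "i.png".

    Decoding is omitted in this sketch; __getitem__ returns the Path.
    """

    def __init__(self, path: Union[Path, str], suffix: str = ".png"):
        self.path = Path(path)
        if not self.path.is_dir():
            raise ValueError(f"Path {path} is not a directory")
        self.suffix = suffix
        # Cache the count once; __len__ reads this instead of a file list.
        self.len = len(list(self.path.glob(f"*{suffix}")))

    def __getitem__(self, index: int) -> Path:
        # The index is used directly as the filename.
        target = self.path / f"{index}{self.suffix}"
        if not target.is_file():
            raise IndexError(f"No file {target} in {self.path}")
        return target

    def __len__(self) -> int:
        return self.len
```

Note that the cached count goes stale if files are added to the directory after construction, which is a reasonable trade-off for a read-only dataset directory.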
7 changes: 4 additions & 3 deletions sslt/data/readers/reader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Any

import numpy as np


class _Reader:
"""
Base class for readers. Readers define an ordered collection of data and
provide methods to access it. This class primarily handles:

1. Definition of data structure and storage.
2. Reading data from the source.

The access is handled by the __getitem__ and __len__ methods, which should be
implemented by a subclass. Readers usually return a single item at a time,
which can be a single image, a single label, etc.
21 changes: 12 additions & 9 deletions sslt/data/readers/tiff_reader.py
@@ -1,33 +1,36 @@
from pathlib import Path
from typing import Union

import numpy as np
import tifffile as tiff
from pathlib import Path

from sslt.data.readers.reader import _Reader


class TiffReader(_Reader):
"""This class loads a TIFF file from a directory. It assumes that the TIFF
files are named with a number as the filename, starting from 0. This is
shown below.
"""This class loads a TIFF file from a directory. It assumes that the TIFF
files are named with a number as the filename, starting from 0. This is
shown below.

```
/path/
├── 0.tiff
├── 1.tiff
├── 2.tiff
└── ...
```

Thus, the element at index `i` will be the file `i.tiff`.
"""

def __init__(self, path: str):
self.path = Path(path)
if not self.path.is_dir():
raise ValueError(f"Path {path} is not a directory")
self.len = len(list(self.path.glob("*.tiff")))

def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
"""Retrieve the TIFF file at the specified index. The index will be
"""Retrieve the TIFF file at the specified index. The index will be
used as the filename of the TIFF file.

Parameters
@@ -58,4 +61,4 @@ def __len__(self) -> int:
int
The number of TIFF files in the directory.
"""
-return len(self.files)
+return self.len
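One detail worth double-checking around the constructor's glob: `Path.glob` suffix patterns are literal, so `*.tif` and `*.tiff` match disjoint sets of files, and the pattern has to agree with the `i.tiff` naming the docstring promises. A quick stdlib illustration:

```python
import tempfile
from pathlib import Path

# Two files with the two common TIFF extensions (hypothetical names).
d = Path(tempfile.mkdtemp())
(d / "0.tiff").touch()
(d / "1.tif").touch()

# Each pattern is a literal suffix match, not "either extension".
tif_matches = sorted(p.name for p in d.glob("*.tif"))
tiff_matches = sorted(p.name for p in d.glob("*.tiff"))
print(tif_matches)   # ['1.tif']
print(tiff_matches)  # ['0.tiff']
```

If a dataset directory mixes both extensions, the reader would either need two globs or a normalization pass over the filenames.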
12 changes: 6 additions & 6 deletions sslt/transforms/transform.py
@@ -1,27 +1,27 @@

from typing import Any, List


class _Transform:
"""This class is a base class for all transforms. Transforms is just a
fancy word for a function that takes an input and returns an output. The
input and output can be anything. However, transforms operates over a
"""This class is a base class for all transforms. Transforms is just a
fancy word for a function that takes an input and returns an output. The
input and output can be anything. However, transforms operates over a
single sample of data and does not require any additional information to
perform the transformation. The __call__ method should be overridden in
subclasses to define the transformation logic.
"""

def __call__(self, *args, **kwargs) -> Any:
"""Implement the transformation logic in this method. Usually, the
"""Implement the transformation logic in this method. Usually, the
transformation is applyied on a single sample of data.
"""
raise NotImplementedError


class TransformPipeline(_Transform):
"""Apply a sequence of transforms to a single sample of data and return the
"""Apply a sequence of transforms to a single sample of data and return the
transformed data.
"""

def __init__(self, transforms: List[_Transform]):
"""Apply a sequence of transforms to a single sample of data and return
the transformed data.