Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for segmentation with different mask extensions #1130

Merged
merged 3 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_csv` and `from_data_frame` to `VideoClassificationData` ([#1117](https://github.com/PyTorchLightning/lightning-flash/pull/1117))

- Added support for `SemanticSegmentationData.from_folders` where mask files have different extensions to the image files ([#1130](https://github.com/PyTorchLightning/lightning-flash/pull/1130))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and separate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down
3 changes: 2 additions & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def from_folders(
├── image_3.png
...

your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this:
your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this
(although the file extensions could be different):

.. code-block::

Expand Down
31 changes: 13 additions & 18 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.data.io.input import DataKeys, ImageLabelsMap, Input
from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
Expand All @@ -33,8 +32,7 @@
SampleCollection = None

if _TORCHVISION_AVAILABLE:
import torchvision
import torchvision.transforms.functional as FT
from torchvision.transforms.functional import to_tensor


class SemanticSegmentationInput(Input):
Expand Down Expand Up @@ -104,9 +102,9 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Load a single sample: read the input image (and optional mask) from disk.

    The input image is decoded to a float tensor in ``[0, 1]``. The mask, if
    present, may use ANY image extension (not just the input's): it is decoded
    with the same ``image_loader`` and the ``[0, 1]`` floats are rescaled by 255
    back to integer class indices.
    """
    filepath = sample[DataKeys.INPUT]
    sample[DataKeys.INPUT] = to_tensor(image_loader(filepath))
    if DataKeys.TARGET in sample:
        # to_tensor yields floats in [0, 1]; multiply by 255 and cast to long to
        # recover the stored per-pixel class index. [0] keeps a single channel.
        # NOTE(review): assumes masks are 8-bit single-class-per-pixel images.
        sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0]
    sample = super().load_sample(sample)
    # Preserve the source path so predictions can be mapped back to files.
    sample[DataKeys.METADATA]["filepath"] = filepath
    return sample
Expand All @@ -124,20 +122,17 @@ def load_data(
files = os.listdir(folder)
files.sort()
if mask_folder is not None:
mask_files = os.listdir(mask_folder)

all_files = set(files).intersection(set(mask_files))
if len(all_files) != len(files) or len(all_files) != len(mask_files):
rank_zero_warn(
f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. Some files"
" have been dropped.",
UserWarning,
mask_files = {os.path.splitext(file)[0]: file for file in os.listdir(mask_folder)}
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
file_names = [os.path.splitext(file)[0] for file in files]

if len(set(file_names) - mask_files.keys()) != 0:
raise ValueError(
f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. All input "
f"files must have a corresponding mask file with the same name."
)

files = [os.path.join(folder, file) for file in all_files]
mask_files = [os.path.join(mask_folder, file) for file in all_files]
files.sort()
mask_files.sort()
files = [os.path.join(folder, file) for file in files]
mask_files = [os.path.join(mask_folder, mask_files[file_name]) for file_name in file_names]
return super().load_data(files, mask_files)
return super().load_data([os.path.join(folder, file) for file in files])

Expand Down Expand Up @@ -172,6 +167,6 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
class SemanticSegmentationDeserializer(ImageDeserializer):
    """Deserializer for serving: converts an incoming image payload to a tensor sample."""

    def serve_load_sample(self, data: str) -> Dict[str, Any]:
        result = super().serve_load_sample(data)
        result[DataKeys.INPUT] = to_tensor(result[DataKeys.INPUT])
        # Record the spatial size (H, W) so outputs can be resized to the input.
        result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]}
        return result
76 changes: 65 additions & 11 deletions tests/image/segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,69 @@ def test_from_folders(tmpdir):

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_different_extensions(tmpdir):
    """Masks may use a different file extension (.tiff) than the images (.png),
    as long as the file stems match."""
    tmp_dir = Path(tmpdir)

    # Create random dummy data: PNG inputs with TIFF masks sharing the same stem.
    os.makedirs(str(tmp_dir / "images"))
    os.makedirs(str(tmp_dir / "targets"))

    images = [
        str(tmp_dir / "images" / "img1.png"),
        str(tmp_dir / "images" / "img2.png"),
        str(tmp_dir / "images" / "img3.png"),
    ]

    targets = [
        str(tmp_dir / "targets" / "img1.tiff"),
        str(tmp_dir / "targets" / "img2.tiff"),
        str(tmp_dir / "targets" / "img3.tiff"),
    ]

    num_classes: int = 2
    img_size: Tuple[int, int] = (128, 128)
    create_random_data(images, targets, img_size, num_classes)

    # Instantiate the data module; every split reuses the same folders.
    dm = SemanticSegmentationData.from_folders(
        train_folder=str(tmp_dir / "images"),
        train_target_folder=str(tmp_dir / "targets"),
        val_folder=str(tmp_dir / "images"),
        val_target_folder=str(tmp_dir / "targets"),
        test_folder=str(tmp_dir / "images"),
        test_target_folder=str(tmp_dir / "targets"),
        batch_size=2,
        num_workers=0,
        num_classes=num_classes,
    )
    assert dm is not None
    assert dm.train_dataloader() is not None
    assert dm.val_dataloader() is not None
    assert dm.test_dataloader() is not None

    # Each split should yield batches of 2 RGB images with integer mask labels.
    for loader in (dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()):
        data = next(iter(loader))
        imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
        assert imgs.shape == (2, 3, 128, 128)
        assert labels.shape == (2, 128, 128)

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_error(tmpdir):
tmp_dir = Path(tmpdir)

# create random dummy data
Expand All @@ -150,22 +212,14 @@ def test_from_folders_warning(tmpdir):

# instantiate the data module

with pytest.warns(UserWarning, match="Found inconsistent files"):
dm = SemanticSegmentationData.from_folders(
with pytest.raises(ValueError, match="Found inconsistent files"):
SemanticSegmentationData.from_folders(
train_folder=str(tmp_dir / "images"),
train_target_folder=str(tmp_dir / "targets"),
batch_size=1,
num_workers=0,
num_classes=num_classes,
)
assert dm is not None
assert dm.train_dataloader() is not None

# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (1, 3, 128, 128)
assert labels.shape == (1, 128, 128)

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
Expand Down