Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add support for segmentation with different mask extensions (#1130)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jan 20, 2022
1 parent 835d32b commit 8b244d7
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 30 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_csv` and `from_data_frame` to `VideoClassificationData` ([#1117](https://github.com/PyTorchLightning/lightning-flash/pull/1117))

- Added support for `SemanticSegmentationData.from_folders` where mask files have different extensions to the image files ([#1130](https://github.com/PyTorchLightning/lightning-flash/pull/1130))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down
3 changes: 2 additions & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def from_folders(
├── image_3.png
...
your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this:
your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this
(although the file extensions could be different):
.. code-block::
Expand Down
31 changes: 13 additions & 18 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.data.io.input import DataKeys, ImageLabelsMap, Input
from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE
Expand All @@ -33,8 +32,7 @@
SampleCollection = None

if _TORCHVISION_AVAILABLE:
import torchvision
import torchvision.transforms.functional as FT
from torchvision.transforms.functional import to_tensor


class SemanticSegmentationInput(Input):
Expand Down Expand Up @@ -104,9 +102,9 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
filepath = sample[DataKeys.INPUT]
sample[DataKeys.INPUT] = FT.to_tensor(image_loader(filepath))
sample[DataKeys.INPUT] = to_tensor(image_loader(filepath))
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = torchvision.io.read_image(sample[DataKeys.TARGET])[0]
sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0]
sample = super().load_sample(sample)
sample[DataKeys.METADATA]["filepath"] = filepath
return sample
Expand All @@ -124,20 +122,17 @@ def load_data(
files = os.listdir(folder)
files.sort()
if mask_folder is not None:
mask_files = os.listdir(mask_folder)

all_files = set(files).intersection(set(mask_files))
if len(all_files) != len(files) or len(all_files) != len(mask_files):
rank_zero_warn(
f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. Some files"
" have been dropped.",
UserWarning,
mask_files = {os.path.splitext(file)[0]: file for file in os.listdir(mask_folder)}
file_names = [os.path.splitext(file)[0] for file in files]

if len(set(file_names) - mask_files.keys()) != 0:
raise ValueError(
f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. All input "
f"files must have a corresponding mask file with the same name."
)

files = [os.path.join(folder, file) for file in all_files]
mask_files = [os.path.join(mask_folder, file) for file in all_files]
files.sort()
mask_files.sort()
files = [os.path.join(folder, file) for file in files]
mask_files = [os.path.join(mask_folder, mask_files[file_name]) for file_name in file_names]
return super().load_data(files, mask_files)
return super().load_data([os.path.join(folder, file) for file in files])

Expand Down Expand Up @@ -172,6 +167,6 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
class SemanticSegmentationDeserializer(ImageDeserializer):
def serve_load_sample(self, data: str) -> Dict[str, Any]:
result = super().serve_load_sample(data)
result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT])
result[DataKeys.INPUT] = to_tensor(result[DataKeys.INPUT])
result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]}
return result
76 changes: 65 additions & 11 deletions tests/image/segmentation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,69 @@ def test_from_folders(tmpdir):

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_warning(tmpdir):
def test_from_folders_different_extensions(tmpdir):
tmp_dir = Path(tmpdir)

# create random dummy data

os.makedirs(str(tmp_dir / "images"))
os.makedirs(str(tmp_dir / "targets"))

images = [
str(tmp_dir / "images" / "img1.png"),
str(tmp_dir / "images" / "img2.png"),
str(tmp_dir / "images" / "img3.png"),
]

targets = [
str(tmp_dir / "targets" / "img1.tiff"),
str(tmp_dir / "targets" / "img2.tiff"),
str(tmp_dir / "targets" / "img3.tiff"),
]

num_classes: int = 2
img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)

# instantiate the data module

dm = SemanticSegmentationData.from_folders(
train_folder=str(tmp_dir / "images"),
train_target_folder=str(tmp_dir / "targets"),
val_folder=str(tmp_dir / "images"),
val_target_folder=str(tmp_dir / "targets"),
test_folder=str(tmp_dir / "images"),
test_target_folder=str(tmp_dir / "targets"),
batch_size=2,
num_workers=0,
num_classes=num_classes,
)
assert dm is not None
assert dm.train_dataloader() is not None
assert dm.val_dataloader() is not None
assert dm.test_dataloader() is not None

# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (2, 3, 128, 128)
assert labels.shape == (2, 128, 128)

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_error(tmpdir):
tmp_dir = Path(tmpdir)

# create random dummy data
Expand All @@ -150,22 +212,14 @@ def test_from_folders_warning(tmpdir):

# instantiate the data module

with pytest.warns(UserWarning, match="Found inconsistent files"):
dm = SemanticSegmentationData.from_folders(
with pytest.raises(ValueError, match="Found inconsistent files"):
SemanticSegmentationData.from_folders(
train_folder=str(tmp_dir / "images"),
train_target_folder=str(tmp_dir / "targets"),
batch_size=1,
num_workers=0,
num_classes=num_classes,
)
assert dm is not None
assert dm.train_dataloader() is not None

# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
assert imgs.shape == (1, 3, 128, 128)
assert labels.shape == (1, 128, 128)

@staticmethod
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
Expand Down

0 comments on commit 8b244d7

Please sign in to comment.