From 8b244d785c5569e9aa7d2b878a5f94af976d3f55 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 20 Jan 2022 12:48:44 +0000 Subject: [PATCH] Add support for segmentation with different mask extensions (#1130) --- CHANGELOG.md | 2 + flash/image/segmentation/data.py | 3 +- flash/image/segmentation/input.py | 31 +++++------ tests/image/segmentation/test_data.py | 76 +++++++++++++++++++++++---- 4 files changed, 82 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9ebc3baaf..fdee506f71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `from_csv` and `from_data_frame` to `VideoClassificationData` ([#1117](https://github.com/PyTorchLightning/lightning-flash/pull/1117)) +- Added support for `SemanticSegmentationData.from_folders` where mask files have different extensions to the image files ([#1130](https://github.com/PyTorchLightning/lightning-flash/pull/1130)) + ### Changed - Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075)) diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index f0ca3f0a79..4677688f5e 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -203,7 +203,8 @@ def from_folders( ├── image_3.png ... - your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this: + your ``train_masks`` folder (passed to the ``train_target_folder`` argument) would need to look like this + (although the file extensions could be different): .. code-block:: diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index 0a15b29512..d03bc5c0e0 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from pytorch_lightning.utilities import rank_zero_warn from flash.core.data.io.input import DataKeys, ImageLabelsMap, Input from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE @@ -33,8 +32,7 @@ SampleCollection = None if _TORCHVISION_AVAILABLE: - import torchvision - import torchvision.transforms.functional as FT + from torchvision.transforms.functional import to_tensor class SemanticSegmentationInput(Input): @@ -104,9 +102,9 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: filepath = sample[DataKeys.INPUT] - sample[DataKeys.INPUT] = FT.to_tensor(image_loader(filepath)) + sample[DataKeys.INPUT] = to_tensor(image_loader(filepath)) if DataKeys.TARGET in sample: - sample[DataKeys.TARGET] = torchvision.io.read_image(sample[DataKeys.TARGET])[0] + sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0] sample = super().load_sample(sample) sample[DataKeys.METADATA]["filepath"] = filepath return sample @@ -124,20 +122,17 @@ def load_data( files = os.listdir(folder) files.sort() if mask_folder is not None: - mask_files = os.listdir(mask_folder) - - all_files = set(files).intersection(set(mask_files)) - if len(all_files) != len(files) or len(all_files) != len(mask_files): - rank_zero_warn( - f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. Some files" - " have been dropped.", - UserWarning, + mask_files = {os.path.splitext(file)[0]: file for file in os.listdir(mask_folder)} + file_names = [os.path.splitext(file)[0] for file in files] + + if len(set(file_names) - mask_files.keys()) != 0: + raise ValueError( + f"Found inconsistent files in input folder: {folder} and mask folder: {mask_folder}. All input " + f"files must have a corresponding mask file with the same name." ) - files = [os.path.join(folder, file) for file in all_files] - mask_files = [os.path.join(mask_folder, file) for file in all_files] - files.sort() - mask_files.sort() + files = [os.path.join(folder, file) for file in files] + mask_files = [os.path.join(mask_folder, mask_files[file_name]) for file_name in file_names] return super().load_data(files, mask_files) return super().load_data([os.path.join(folder, file) for file in files]) @@ -172,6 +167,6 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): def serve_load_sample(self, data: str) -> Dict[str, Any]: result = super().serve_load_sample(data) - result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT]) + result[DataKeys.INPUT] = to_tensor(result[DataKeys.INPUT]) result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]} return result diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index 61398320e1..7f91145cf6 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -126,7 +126,69 @@ def test_from_folders(tmpdir): @staticmethod @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") - def test_from_folders_warning(tmpdir): + def test_from_folders_different_extensions(tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [ + str(tmp_dir / "images" / "img1.png"), + str(tmp_dir / "images" / "img2.png"), + str(tmp_dir / "images" / "img3.png"), + ] + + targets = [ + str(tmp_dir / "targets" / "img1.tiff"), + str(tmp_dir / "targets" / "img2.tiff"), + str(tmp_dir / "targets" / "img3.tiff"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (128, 128) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_folders( + train_folder=str(tmp_dir / "images"), + train_target_folder=str(tmp_dir / "targets"), + val_folder=str(tmp_dir / "images"), + val_target_folder=str(tmp_dir / "targets"), + test_folder=str(tmp_dir / "images"), + test_target_folder=str(tmp_dir / "targets"), + batch_size=2, + num_workers=0, + num_classes=num_classes, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) + + # check val data + data = next(iter(dm.val_dataloader())) + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) + + @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + def test_from_folders_error(tmpdir): tmp_dir = Path(tmpdir) # create random dummy data @@ -150,22 +212,14 @@ def test_from_folders_warning(tmpdir): # instantiate the data module - with pytest.warns(UserWarning, match="Found inconsistent files"): - dm = SemanticSegmentationData.from_folders( + with pytest.raises(ValueError, match="Found inconsistent files"): + SemanticSegmentationData.from_folders( train_folder=str(tmp_dir / "images"), train_target_folder=str(tmp_dir / "targets"), batch_size=1, num_workers=0, num_classes=num_classes, ) - assert dm is not None - assert dm.train_dataloader() is not None - - # check training data - data = next(iter(dm.train_dataloader())) - imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] - assert imgs.shape == (1, 3, 128, 128) - assert labels.shape == (1, 128, 128) @staticmethod @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")