From 47b3dd60656906bd9de038ffd00bca8845c46f2c Mon Sep 17 00:00:00 2001 From: Patrick Barton Date: Tue, 4 Jan 2022 21:42:57 +0100 Subject: [PATCH] Make Segmentation Tasks reproducible (#1094) Co-authored-by: Ethan Harris --- CHANGELOG.md | 2 ++ flash/image/segmentation/input.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48639ade80..3dd39f40f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where passing `predict_data_frame` to `ImageClassificationData.from_data_frame` raised an error ([#1088](https://github.com/PyTorchLightning/lightning-flash/pull/1088)) +- Fixed a bug where segmentation files / masks were loaded with an inconsistent ordering ([#1094](https://github.com/PyTorchLightning/lightning-flash/pull/1094)) + ### Removed ## [0.6.0] - 2021-13-12 diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index 2a171e749e..1201597a0b 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -124,6 +124,7 @@ def load_data( ) -> List[Dict[str, Any]]: self.load_labels_map(num_classes, labels_map) files = os.listdir(folder) + files.sort() if mask_folder is not None: mask_files = os.listdir(mask_folder) @@ -137,6 +138,8 @@ def load_data( files = [os.path.join(folder, file) for file in all_files] mask_files = [os.path.join(mask_folder, file) for file in all_files] + files.sort() + mask_files.sort() return super().load_data(files, mask_files) return super().load_data(files)