Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Make Segmentation Tasks reproducible (#1094)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
bartonp2 and ethanwharris authored Jan 4, 2022
1 parent 088bd18 commit 47b3dd6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where passing `predict_data_frame` to `ImageClassificationData.from_data_frame` raised an error ([#1088](https://github.com/PyTorchLightning/lightning-flash/pull/1088))

- Fixed a bug where segmentation files / masks were loaded with an inconsistent ordering ([#1094](https://github.com/PyTorchLightning/lightning-flash/pull/1094))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
3 changes: 3 additions & 0 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def load_data(
) -> List[Dict[str, Any]]:
self.load_labels_map(num_classes, labels_map)
files = os.listdir(folder)
files.sort()
if mask_folder is not None:
mask_files = os.listdir(mask_folder)

Expand All @@ -137,6 +138,8 @@ def load_data(

files = [os.path.join(folder, file) for file in all_files]
mask_files = [os.path.join(mask_folder, file) for file in all_files]
files.sort()
mask_files.sort()
return super().load_data(files, mask_files)
return super().load_data(files)

Expand Down

0 comments on commit 47b3dd6

Please sign in to comment.