Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Bugfix: _show_images_and_labels - flatten axes to allow plotting fo…
Browse files Browse the repository at this point in the history
…r subplots (#1339)

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
krshrimali and ethanwharris authored May 11, 2022
1 parent b112837 commit 61e1a2d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed image classification data `show_train_batch` for subplots with rows > 1. ([#1339](https://github.com/PyTorchLightning/lightning-flash/pull/1315))

- Fixed support for all the versions (including the latest and older) of `baal`. ([#1315](https://github.com/PyTorchLightning/lightning-flash/pull/1315))

- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))
Expand Down
3 changes: 2 additions & 1 deletion flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,8 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
fig.suptitle(title)

if not isinstance(axs, np.ndarray):
axs = [axs]
axs = np.array(axs)
axs = axs.flatten()

for i, ax in enumerate(axs):
# unpack images and labels
Expand Down
68 changes: 68 additions & 0 deletions tests/image/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,74 @@ def test_from_filepaths_visualise(tmpdir):
dm.show_train_batch(["per_sample_transform", "per_batch_transform"])


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise_subplots_exceding_max_cols(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
str(tmpdir / "e_1.png"),
] * 8

dm = ImageClassificationData.from_files(
train_files=train_images,
train_targets=[0, 3, 6, 9, 8, 9, 1, 2],
val_files=train_images,
val_targets=[1, 4, 7, 8, 9, 8, 7, 1],
test_files=train_images,
test_targets=[2, 5, 8, 9, 7, 1, 2, 3],
batch_size=8,
num_workers=0,
)

# disable visualisation for testing
assert dm.data_fetcher.block_viz_window is True
dm.set_block_viz_window(False)
assert dm.data_fetcher.block_viz_window is False

# call show functions
# dm.show_train_batch()
dm.show_train_batch("per_sample_transform")
dm.show_train_batch(["per_sample_transform", "per_batch_transform"])


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise_subplots_single_image(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "e").mkdir()
_rand_image().save(tmpdir / "e_1.png")

train_images = [
str(tmpdir / "e_1.png"),
]

dm = ImageClassificationData.from_files(
train_files=train_images,
train_targets=[0],
val_files=train_images,
val_targets=[1],
test_files=train_images,
test_targets=[2],
batch_size=1,
num_workers=0,
)

# disable visualisation for testing
assert dm.data_fetcher.block_viz_window is True
dm.set_block_viz_window(False)
assert dm.data_fetcher.block_viz_window is False

# call show functions
# dm.show_train_batch()
dm.show_train_batch("per_sample_transform")
dm.show_train_batch(["per_sample_transform", "per_batch_transform"])


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise_multilabel(tmpdir):
Expand Down

0 comments on commit 61e1a2d

Please sign in to comment.