diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5cb78ee80b..48639ade80 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079))
 
+- Fixed a bug where passing `predict_data_frame` to `ImageClassificationData.from_data_frame` raised an error ([#1088](https://github.com/PyTorchLightning/lightning-flash/pull/1088))
+
 ### Removed
 
 ## [0.6.0] - 2021-13-12
diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py
index 3bcf3611e1..0ab5395f71 100644
--- a/flash/image/classification/data.py
+++ b/flash/image/classification/data.py
@@ -217,7 +217,7 @@ def from_data_frame(
         train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver)
         val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver)
         test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
-        predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver)
+        predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)
 
         return cls(
             input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 59bf13ccb5..6213f26708 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -17,6 +17,7 @@
 from typing import Any, List, Tuple
 
 import numpy as np
+import pandas as pd
 import pytest
 import torch
 import torch.nn as nn
@@ -84,6 +85,61 @@ def test_from_filepaths_smoke(tmpdir):
     assert sorted(list(labels.numpy())) == [1, 2]
 
 
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_data_frame_smoke(tmpdir):
+    """Smoke test ``ImageClassificationData.from_data_frame`` with train, val, test and predict data frames."""
+    tmpdir = Path(tmpdir)
+
+    df = pd.DataFrame(
+        {"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], "target": [0, 1, 1]}
+    )
+
+    # Write one random image per row so that every file named in the data frame exists under ``tmpdir``.
+    for file_name in df["file"]:
+        _rand_image().save(tmpdir / file_name)
+
+    img_data = ImageClassificationData.from_data_frame(
+        "file",
+        "target",
+        train_images_root=str(tmpdir),
+        val_images_root=str(tmpdir),
+        test_images_root=str(tmpdir),
+        train_data_frame=df[df.split == "train"],
+        val_data_frame=df[df.split == "valid"],
+        test_data_frame=df[df.split == "test"],
+        predict_images_root=str(tmpdir),
+        batch_size=1,
+        predict_data_frame=df,
+    )
+
+    assert img_data.train_dataloader() is not None
+    assert img_data.val_dataloader() is not None
+    assert img_data.test_dataloader() is not None
+    assert img_data.predict_dataloader() is not None
+
+    data = next(iter(img_data.train_dataloader()))
+    imgs, labels = data["input"], data["target"]
+    assert imgs.shape == (1, 3, 196, 196)
+    assert labels.shape == (1,)
+    assert sorted(list(labels.numpy())) == [0]
+
+    data = next(iter(img_data.val_dataloader()))
+    imgs, labels = data["input"], data["target"]
+    assert imgs.shape == (1, 3, 196, 196)
+    assert labels.shape == (1,)
+    assert sorted(list(labels.numpy())) == [1]
+
+    data = next(iter(img_data.test_dataloader()))
+    imgs, labels = data["input"], data["target"]
+    assert imgs.shape == (1, 3, 196, 196)
+    assert labels.shape == (1,)
+    assert sorted(list(labels.numpy())) == [1]
+
+    data = next(iter(img_data.predict_dataloader()))
+    imgs = data["input"]
+    assert imgs.shape == (1, 3, 196, 196)
+
+
 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
 def test_from_filepaths_list_image_paths(tmpdir):
     tmpdir = Path(tmpdir)