Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fix from_data_frame factory method with prediction df (#1088)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Boesl <michael.boesl@continental-corporation.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
3 people authored Dec 30, 2021
1 parent 4d00c34 commit 1a599ef
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079))

- Fixed a bug where passing `predict_data_frame` to `ImageClassificationData.from_data_frame` raised an error ([#1088](https://github.com/PyTorchLightning/lightning-flash/pull/1088))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def from_data_frame(
train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver)
val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver)
test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver)
predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)

return cls(
input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
Expand Down
53 changes: 53 additions & 0 deletions tests/image/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, List, Tuple

import numpy as np
import pandas as pd
import pytest
import torch
import torch.nn as nn
Expand Down Expand Up @@ -84,6 +85,58 @@ def test_from_filepaths_smoke(tmpdir):
assert sorted(list(labels.numpy())) == [1, 2]


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_data_frame_smoke(tmpdir):
tmpdir = Path(tmpdir)

df = pd.DataFrame(
{"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], "target": [0, 1, 1]}
)

[_rand_image().save(tmpdir / row.file) for i, row in df.iterrows()]

img_data = ImageClassificationData.from_data_frame(
"file",
"target",
train_images_root=str(tmpdir),
val_images_root=str(tmpdir),
test_images_root=str(tmpdir),
train_data_frame=df[df.split == "train"],
val_data_frame=df[df.split == "valid"],
test_data_frame=df[df.split == "test"],
predict_images_root=str(tmpdir),
batch_size=1,
predict_data_frame=df,
)

assert img_data.train_dataloader() is not None
assert img_data.val_dataloader() is not None
assert img_data.test_dataloader() is not None
assert img_data.predict_dataloader() is not None

data = next(iter(img_data.train_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1,)
assert sorted(list(labels.numpy())) == [0]

data = next(iter(img_data.val_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1,)
assert sorted(list(labels.numpy())) == [1]

data = next(iter(img_data.test_dataloader()))
imgs, labels = data["input"], data["target"]
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1,)
assert sorted(list(labels.numpy())) == [1]

data = next(iter(img_data.predict_dataloader()))
imgs = data["input"]
assert imgs.shape == (1, 3, 196, 196)


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_filepaths_list_image_paths(tmpdir):
tmpdir = Path(tmpdir)
Expand Down

0 comments on commit 1a599ef

Please sign in to comment.