Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Cat loader2 #117

Merged
merged 8 commits into from
Feb 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions flash/data/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Dict, List, Union
from typing import Any, Dict, List, Union

import pandas as pd


def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = True) -> Union[Dict, List]:
def labels_from_categorical_csv(
csv: str,
index_col: str,
feature_cols: List,
return_dict: bool = True,
index_col_collate_fn: Any = None
) -> Union[Dict, List]:
"""
Returns a dictionary with {index_col: label} for each entry in the csv.

Expand All @@ -17,10 +23,15 @@ def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = Tr
df = pd.read_csv(csv)
# get names
names = df[index_col].to_list()
del df[index_col]

# apply colate fn to index_col
if index_col_collate_fn:
for i in range(len(names)):
names[i] = index_col_collate_fn(names[i])

# everything else is binary
labels = df.to_numpy().argmax(1).tolist()
feature_df = df[feature_cols]
labels = feature_df.to_numpy().argmax(1).tolist()

if return_dict:
labels = {name: label for name, label in zip(names, labels)}
Expand Down
34 changes: 25 additions & 9 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
img = self.transform(img)
label = None
if self.has_dict_labels:
name = os.path.basename(filename)
name = os.path.splitext(filename)[0]
name = os.path.basename(name)
label = self.labels[name]

elif self.has_labels:
Expand Down Expand Up @@ -256,6 +257,7 @@ def from_filepaths(
train_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
train_labels: Optional[Sequence] = None,
train_transform: Optional[Callable] = _default_train_transforms,
valid_split: Union[None, float] = None,
valid_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
valid_labels: Optional[Sequence] = None,
valid_transform: Optional[Callable] = _default_valid_transforms,
Expand All @@ -264,6 +266,7 @@ def from_filepaths(
loader: Callable = _pil_loader,
batch_size: int = 64,
num_workers: Optional[int] = None,
seed: int = 1234,
**kwargs
):
"""Creates a ImageClassificationData object from lists of image filepaths and labels
Expand All @@ -272,6 +275,7 @@ def from_filepaths(
train_filepaths: string or sequence of file paths for training dataset. Defaults to ``None``.
train_labels: sequence of labels for training dataset. Defaults to ``None``.
train_transform: transforms for training dataset. Defaults to ``None``.
valid_split: if not None, generates val split from train dataloader using this value.
valid_filepaths: string or sequence of file paths for validation dataset. Defaults to ``None``.
valid_labels: sequence of labels for validation dataset. Defaults to ``None``.
valid_transform: transforms for validation and testing dataset. Defaults to ``None``.
Expand All @@ -281,6 +285,7 @@ def from_filepaths(
batch_size: the batchsize to use for parallel loading. Defaults to ``64``.
num_workers: The number of workers to use for parallelized loading.
Defaults to ``None`` which equals the number of available CPU threads.
seed: Used for the train/val splits when valid_split is not None

Returns:
ImageClassificationData: The constructed data module.
Expand Down Expand Up @@ -319,14 +324,25 @@ def from_filepaths(
loader=loader,
transform=train_transform,
)
valid_ds = (
FilepathDataset(
filepaths=valid_filepaths,
labels=valid_labels,
loader=loader,
transform=valid_transform,
) if valid_filepaths is not None else None
)

if valid_split:
full_length = len(train_ds)
train_split = int((1.0 - valid_split) * full_length)
valid_split = full_length - train_split
train_ds, valid_ds = torch.utils.data.random_split(
train_ds,
[train_split, valid_split],
generator=torch.Generator().manual_seed(seed)
)
else:
valid_ds = (
FilepathDataset(
filepaths=valid_filepaths,
labels=valid_labels,
loader=loader,
transform=valid_transform,
) if valid_filepaths is not None else None
)

test_ds = (
FilepathDataset(
Expand Down
28 changes: 22 additions & 6 deletions tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,33 @@ def test_categorical_csv_labels(tmpdir):
train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv')
text_file = open(train_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
)
text_file.close()

valid_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv')
text_file = open(valid_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n'
'my_id,label_a,label_b,label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n'
)
text_file.close()

test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv')
text_file = open(test_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
)
text_file.close()

train_labels = labels_from_categorical_csv(train_csv, 'my_id')
valid_labels = labels_from_categorical_csv(valid_csv, 'my_id')
test_labels = labels_from_categorical_csv(test_csv, 'my_id')
def index_col_collate_fn(x):
return os.path.splitext(x)[0]

train_labels = labels_from_categorical_csv(
train_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)
valid_labels = labels_from_categorical_csv(
valid_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)
test_labels = labels_from_categorical_csv(
test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)

data = ImageClassificationData.from_filepaths(
batch_size=2,
Expand All @@ -134,6 +140,16 @@ def test_categorical_csv_labels(tmpdir):
for (x, y) in data.test_dataloader():
assert len(x) == 2

data = ImageClassificationData.from_filepaths(
batch_size=2,
train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'),
train_labels=train_labels,
valid_split=0.5
)

for (x, y) in data.val_dataloader():
assert len(x) == 1


def test_from_folders(tmpdir):
train_dir = Path(tmpdir / "train")
Expand Down