Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding CDDB dataset #442

Merged
merged 4 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion src/renate/benchmark/datasets/vision_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import gdown
import pandas as pd
import torch
import torchvision
Expand All @@ -14,7 +15,13 @@
from renate.benchmark.datasets.base import DataIncrementalDataModule
from renate.data import ImageDataset
from renate.data.data_module import RenateDataModule
from renate.utils.file import download_and_unzip_file, download_file, download_folder_from_s3
from renate.utils.file import (
download_and_unzip_file,
download_file,
download_file_from_s3,
download_folder_from_s3,
extract_file,
)


class TinyImageNetDataModule(RenateDataModule):
Expand Down Expand Up @@ -402,3 +409,65 @@ def _get_filepaths_and_labels(self, split: str) -> Tuple[List[str], List[int]]:
data = list(df.path.apply(lambda x: os.path.join(path, x)))
labels = list(df.label)
return data, labels


class CDDBDataModule(DataIncrementalDataModule):
    """Data module for the CDDB (Continual Deepfake Detection Benchmark) dataset.

    The dataset is distributed as a nested archive (a ``.zip`` that contains a
    ``.tar``) and is downloaded either from a public Google Drive file or from a
    user-provided S3 bucket. Each ``domain`` is loaded as a separate incremental
    data id.

    Args:
        data_path: Local directory where the dataset is downloaded and extracted.
        src_bucket: Optional S3 bucket to download the archive from instead of
            Google Drive.
        src_object_name: Object key of the archive inside ``src_bucket``.
        domain: The CDDB sub-dataset to load; must be one of ``domains``.
        val_size: Fraction of the training data held out for validation.
        seed: Seed used for the train/validation split.
    """

    md5s = {
        "CDDB.tar.zip": "823b6496270ba03019dbd6af60cbcb6b",
    }

    domains = ["gaugan", "biggan", "wild", "whichfaceisreal", "san"]
    dataset_stats = {
        "CDDB": dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    }
    google_drive_id = "1NgB8ytBMFBFwyXJQvdVT_yek1EaaEHrg"

    def __init__(
        self,
        data_path: Union[Path, str],
        src_bucket: Optional[str] = None,
        src_object_name: Optional[str] = None,
        domain: str = "gaugan",
        val_size: float = defaults.VALIDATION_SIZE,
        seed: int = defaults.SEED,
    ):
        # Validate with a real exception: `assert` is stripped when Python runs with -O.
        if domain not in self.domains:
            raise ValueError(f"Unknown CDDB domain `{domain}`. Expected one of {self.domains}.")
        super().__init__(
            data_path=data_path,
            data_id=domain.lower(),
            src_bucket=src_bucket,
            src_object_name=src_object_name,
            val_size=val_size,
            seed=seed,
        )

    def prepare_data(self) -> None:
        """Download the CDDB archive if it is not present, then extract it.

        Extraction happens twice because the download is a ``.zip`` wrapping a ``.tar``.
        """
        file_name = "CDDB.tar.zip"
        self._dataset_name = ""
        if not self._verify_file(file_name):
            if self._src_bucket is None:
                gdown.download(
                    # Save under the exact name `_verify_file` and `extract_file`
                    # expect, rather than whatever file name Google Drive reports.
                    output=os.path.join(self._data_path, file_name),
                    quiet=False,
                    url=f"https://drive.google.com/u/0/uc?id={self.google_drive_id}&export=download&confirm=pbef",  # noqa: E501
                )
            else:
                download_file_from_s3(
                    dst=os.path.join(self._data_path, file_name),
                    src_bucket=self._src_bucket,
                    src_object_name=self._src_object_name,
                )
            # First unpack the zip, which yields CDDB.tar, then unpack the tar.
            extract_file(data_path=self._data_path, file_name=file_name, dataset_name="")
            extract_file(data_path=self._data_path, file_name="CDDB.tar", dataset_name="")

    def setup(self) -> None:
        """Create train/validation/test datasets for the configured domain."""
        # Set here (not in prepare_data) because the zip+tar extraction leaves the
        # data under a top-level `CDDB/` folder.
        self._dataset_name = "CDDB"
        train_path = self._get_filepaths_and_labels("train")
        train_data = torchvision.datasets.ImageFolder(train_path)
        self._train_data, self._val_data = self._split_train_val_data(train_data)
        test_path = self._get_filepaths_and_labels("val")
        self._test_data = torchvision.datasets.ImageFolder(test_path)

    def _get_filepaths_and_labels(self, split: str) -> str:
        """Return the image directory of ``split`` for the current domain.

        Note: despite the (historical) name, this returns a directory path consumed
        by ``torchvision.datasets.ImageFolder``, not lists of files and labels.
        """
        return os.path.join(self._data_path, self._dataset_name, self.data_id, split)
35 changes: 35 additions & 0 deletions src/renate/benchmark/experiment_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule
from renate.benchmark.datasets.vision_datasets import (
CDDBDataModule,
CLEARDataModule,
DomainNetDataModule,
TorchVisionDataModule,
Expand Down Expand Up @@ -174,6 +175,14 @@ def get_data_module(
val_size=val_size,
seed=seed,
)
if dataset_name == "CDDB":
return CDDBDataModule(
data_path=data_path,
src_bucket=src_bucket,
src_object_name=src_object_name,
val_size=val_size,
seed=seed,
)
raise ValueError(f"Unknown dataset `{dataset_name}`.")


Expand Down Expand Up @@ -335,6 +344,11 @@ def _get_normalize_transform(dataset_name):
DomainNetDataModule.dataset_stats["all"]["mean"],
DomainNetDataModule.dataset_stats["all"]["std"],
)
if dataset_name == "CDDB":
return transforms.Normalize(
CDDBDataModule.dataset_stats["CDDB"]["mean"],
CDDBDataModule.dataset_stats["CDDB"]["std"],
)


def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Optional[Callable]:
Expand Down Expand Up @@ -391,6 +405,17 @@ def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Opti
_get_normalize_transform(dataset_name),
]
)
if dataset_name == "CDDB":
return transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=63 / 255),
transforms.ToTensor(),
_get_normalize_transform(dataset_name),
]
)

raise ValueError(f"Unknown dataset `{dataset_name}`.")


Expand Down Expand Up @@ -442,6 +467,16 @@ def test_transform(
_get_normalize_transform(dataset_name),
]
)
if dataset_name == "CDDB":
return transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
_get_normalize_transform(dataset_name),
]
)

raise ValueError(f"Unknown dataset `{dataset_name}`.")


Expand Down
15 changes: 11 additions & 4 deletions src/renate/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil
from pathlib import Path
from tarfile import TarFile
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile
Expand Down Expand Up @@ -287,9 +288,15 @@ def delete_file_from_s3(bucket: str, object_name: str) -> None:
s3_client.delete_object(Bucket=bucket, Key=str(object_name))


def unzip_file(dataset_name: str, data_path: Union[str, Path], file_name: str) -> None:
"""Extract .zip files into folder named with dataset name."""
with ZipFile(os.path.join(data_path, dataset_name, file_name)) as f:
def extract_file(dataset_name: str, data_path: Union[str, Path], file_name: str) -> None:
    """Extract a ``.zip`` or ``.tar`` archive into the folder named after the dataset.

    The archive is expected at ``data_path/dataset_name/file_name`` and is extracted
    in place into ``data_path/dataset_name``. The format is chosen from the file
    extension.

    Args:
        dataset_name: Name of the dataset folder that contains the archive.
        data_path: Root directory holding the dataset folder.
        file_name: Archive file name; must end in ``.zip`` or ``.tar``.

    Raises:
        ValueError: If ``file_name`` has neither a ``.zip`` nor a ``.tar`` extension.
    """
    archive_path = os.path.join(data_path, dataset_name, file_name)
    target_dir = os.path.join(data_path, dataset_name)
    if file_name.endswith(".zip"):
        with ZipFile(archive_path) as archive:
            archive.extractall(target_dir)
    elif file_name.endswith(".tar"):
        # TarFile.open is the documented entry point and (unlike the bare
        # constructor) transparently handles compressed tar streams.
        # NOTE(review): extractall trusts archive member paths; the archives used
        # here are first-party, but untrusted tars would need a safe-extract filter.
        with TarFile.open(archive_path) as archive:
            archive.extractall(target_dir)
    else:
        raise ValueError("Unknown compressed format extension. Only Zip/Tar supported.")


Expand Down Expand Up @@ -328,7 +335,7 @@ def download_and_unzip_file(
) -> None:
"""A helper function to download data .zips and uncompress them."""
download_file(dataset_name, data_path, src_bucket, src_object_name, url, file_name)
unzip_file(dataset_name, data_path, file_name)
extract_file(dataset_name, data_path, file_name)


def save_pandas_df_to_csv(df: pd.DataFrame, file_path: Union[str, Path]) -> pd.DataFrame:
Expand Down
18 changes: 16 additions & 2 deletions test/renate/benchmark/test_experimentation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from renate.benchmark import experiment_config
from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule
from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule
from renate.benchmark.datasets.vision_datasets import (
CDDBDataModule,
CLEARDataModule,
TorchVisionDataModule,
)
from renate.benchmark.experiment_config import (
data_module_fn,
get_data_module,
Expand Down Expand Up @@ -95,6 +99,7 @@ def test_model_fn_fails_for_unknown_model():
"label",
),
("MultiText", MultiTextDataModule, "distilbert-base-uncased", None, None),
("CDDB", CDDBDataModule, None, None, None),
),
)
def test_get_data_module(
Expand Down Expand Up @@ -237,6 +242,13 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir):
DataIncrementalScenario,
2,
),
(
"DataIncrementalScenario",
"CDDB",
{"data_ids": ("biggan", "wild")},
DataIncrementalScenario,
2,
),
),
ids=[
"class_incremental",
Expand All @@ -250,6 +262,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir):
"wild_time_image_all_tasks",
"domainnet_by data_id",
"domainnet by groupings",
"cddb by dataid",
],
)
@pytest.mark.parametrize("val_size", (0, 0.5), ids=["no_val", "val"])
Expand Down Expand Up @@ -282,7 +295,7 @@ def test_data_module_fn(
elif expected_scenario_class == DataIncrementalScenario:
if "pretrained_model_name_or_path" in scenario_kwargs:
assert scenario._data_module._tokenizer is not None
elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet"]:
elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet", "CDDB"]:
assert scenario._data_module._tokenizer is None
assert scenario._num_tasks == expected_num_tasks

Expand All @@ -296,6 +309,7 @@ def test_data_module_fn(
("CIFAR100", Compose, Normalize, "ResNet18CIFAR"),
("CLEAR10", Compose, Compose, "ResNet18"),
("DomainNet", Compose, Compose, "VisionTransformerB16"),
("CDDB", Compose, Compose, None),
("hfd-rotten_tomatoes", type(None), type(None), "HuggingFaceTransformer"),
("fmow", Compose, Compose, "ResNet18"),
("yearbook", ToTensor, ToTensor, "ResNet18CIFAR"),
Expand Down