From 3cba3180522b3e26a0908a678a082ac492a33898 Mon Sep 17 00:00:00 2001
From: Prabhu Teja
Date: Mon, 2 Oct 2023 17:50:08 +0200
Subject: [PATCH 1/4] Adding CDDB dataset

---
 .../benchmark/datasets/vision_datasets.py | 71 ++++++++++++++++++-
 src/renate/benchmark/experiment_config.py | 35 +++++++++
 src/renate/utils/file.py                  | 13 ++--
 3 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py
index 7368ae33..08e0d9ce 100644
--- a/src/renate/benchmark/datasets/vision_datasets.py
+++ b/src/renate/benchmark/datasets/vision_datasets.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import gdown
 import pandas as pd
 import torch
 import torchvision
@@ -14,7 +15,13 @@
 from renate.benchmark.datasets.base import DataIncrementalDataModule
 from renate.data import ImageDataset
 from renate.data.data_module import RenateDataModule
-from renate.utils.file import download_and_unzip_file, download_file, download_folder_from_s3
+from renate.utils.file import (
+    download_and_unzip_file,
+    download_file,
+    download_file_from_s3,
+    download_folder_from_s3,
+    extract_file,
+)
 
 
 class TinyImageNetDataModule(RenateDataModule):
@@ -402,3 +409,65 @@ def _get_filepaths_and_labels(self, split: str) -> Tuple[List[str], List[int]]:
         data = list(df.path.apply(lambda x: os.path.join(path, x)))
         labels = list(df.label)
         return data, labels
+
+
+class CDDBDataModule(DataIncrementalDataModule):
+    md5s = {
+        "CDDB.tar.zip": "823b6496270ba03019dbd6af60cbcb6b",
+    }
+
+    domains = ["gaugan", "biggan", "wild", "whichfaceisreal", "san"]
+    dataset_stats = {
+        "CDDB": dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    }
+    google_drive_id = "1NgB8ytBMFBFwyXJQvdVT_yek1EaaEHrg"
+
+    def __init__(
+        self,
+        data_path: Union[Path, str],
+        src_bucket: Optional[str] = None,
+        src_object_name: Optional[str] = None,
+        domain: str = "gaugan",
+        val_size: float = defaults.VALIDATION_SIZE,
+        seed: int = defaults.SEED,
+    ):
+        assert domain in self.domains
+        super().__init__(
+            data_path=data_path,
+            data_id=domain.lower(),
+            src_bucket=src_bucket,
+            src_object_name=src_object_name,
+            val_size=val_size,
+            seed=seed,
+        )
+
+    def prepare_data(self) -> None:
+        """Download the CDDB dataset."""
+        file_name = "CDDB.tar.zip"
+        self._dataset_name = ""
+        if not self._verify_file(file_name):
+            if self._src_bucket is None:
+                gdown.download(
+                    output=self._data_path,
+                    quiet=False,
+                    url=f"https://drive.google.com/u/0/uc?id={self.google_drive_id}&export=download&confirm=pbef",  # noqa: E501
+                )
+            else:
+                download_file_from_s3(
+                    dst=os.path.join(self._data_path, file_name),
+                    src_bucket=self._src_bucket,
+                    src_object_name=self._src_object_name,
+                )
+        extract_file(data_path=self._data_path, file_name="CDDB.tar.zip", dataset_name="")
+        extract_file(data_path=self._data_path, file_name="CDDB.tar", dataset_name="")
+
+    def setup(self) -> None:
+        self._dataset_name = "CDDB"  # needed because the download is a .zip wrapping a .tar
+        train_path = self._get_filepaths_and_labels("train")
+        train_data = torchvision.datasets.ImageFolder(train_path)
+        self._train_data, self._val_data = self._split_train_val_data(train_data)
+        test_path = self._get_filepaths_and_labels("val")
+        self._test_data = torchvision.datasets.ImageFolder(test_path)
+
+    def _get_filepaths_and_labels(self, split: str) -> str:
+        return os.path.join(self._data_path, self._dataset_name, self.data_id, split)
diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py
index 16ba10b0..2d07a8aa 100644
--- a/src/renate/benchmark/experiment_config.py
+++ b/src/renate/benchmark/experiment_config.py
@@ -14,6 +14,7 @@
 from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule
 from renate.benchmark.datasets.vision_datasets import (
+    CDDBDataModule,
     CLEARDataModule,
     DomainNetDataModule,
     TorchVisionDataModule,
@@ -174,6 +175,14 @@ def get_data_module(
             val_size=val_size,
             seed=seed,
         )
+    if dataset_name == "CDDB":
+        return CDDBDataModule(
+            data_path=data_path,
+            src_bucket=src_bucket,
+            src_object_name=src_object_name,
+            val_size=val_size,
+            seed=seed,
+        )
     raise ValueError(f"Unknown dataset `{dataset_name}`.")
@@ -335,6 +344,11 @@ def _get_normalize_transform(dataset_name):
             DomainNetDataModule.dataset_stats["all"]["mean"],
             DomainNetDataModule.dataset_stats["all"]["std"],
         )
+    if dataset_name == "CDDB":
+        return transforms.Normalize(
+            CDDBDataModule.dataset_stats["CDDB"]["mean"],
+            CDDBDataModule.dataset_stats["CDDB"]["std"],
+        )
 
 
 def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Optional[Callable]:
@@ -391,6 +405,17 @@ def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Opti
                 _get_normalize_transform(dataset_name),
             ]
         )
+    if dataset_name == "CDDB":
+        return transforms.Compose(
+            [
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ColorJitter(brightness=63 / 255),
+                transforms.ToTensor(),
+                _get_normalize_transform(dataset_name),
+            ]
+        )
+
     raise ValueError(f"Unknown dataset `{dataset_name}`.")
@@ -442,6 +467,16 @@ def test_transform(
                 _get_normalize_transform(dataset_name),
             ]
         )
+    if dataset_name == "CDDB":
+        return transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                _get_normalize_transform(dataset_name),
+            ]
+        )
+
     raise ValueError(f"Unknown dataset `{dataset_name}`.")
diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py
index f7db7a68..14f4defd 100644
--- a/src/renate/utils/file.py
+++ b/src/renate/utils/file.py
@@ -4,6 +4,7 @@
 import os
 import shutil
 from pathlib import Path
+from tarfile import TarFile
 from typing import List, Optional, Tuple, Union
 from urllib.parse import urlparse
 from zipfile import ZipFile
@@ -287,9 +288,13 @@ def delete_file_from_s3(bucket: str, object_name: str) -> None:
     s3_client.delete_object(Bucket=bucket, Key=str(object_name))
 
 
-def unzip_file(dataset_name: str, data_path: Union[str, Path], file_name: str) -> None:
-    """Extract .zip files into folder named with dataset name."""
-    with ZipFile(os.path.join(data_path, dataset_name, file_name)) as f:
+def extract_file(dataset_name: str, data_path: Union[str, Path], file_name: str) -> None:
+    """Extract a .zip or .tar file (by extension) into a folder named after the dataset."""
+    if file_name.endswith("zip"):
+        Extractor = ZipFile
+    elif file_name.endswith(".tar"):
+        Extractor = TarFile
+    with Extractor(os.path.join(data_path, dataset_name, file_name)) as f:
         f.extractall(os.path.join(data_path, dataset_name))
@@ -328,7 +333,7 @@ def download_and_unzip_file(
 ) -> None:
     """A helper function to download data .zips and uncompress them."""
     download_file(dataset_name, data_path, src_bucket, src_object_name, url, file_name)
-    unzip_file(dataset_name, data_path, file_name)
+    extract_file(dataset_name, data_path, file_name)
 
 
 def save_pandas_df_to_csv(df: pd.DataFrame, file_path: Union[str, Path]) -> pd.DataFrame:
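A minimal usage sketch of the data module added in PATCH 1/4 (not part of the patch itself). It relies only on what the diff above defines — the `CDDBDataModule` constructor, `prepare_data()` and `setup()` — plus a `train_data()` accessor assumed to be inherited from `RenateDataModule`; the local path and hyperparameter values are placeholders.

```python
from renate.benchmark.datasets.vision_datasets import CDDBDataModule

# Build the data module for one CDDB domain; "./data" is a placeholder cache directory.
data_module = CDDBDataModule(
    data_path="./data",
    domain="biggan",        # must be one of CDDBDataModule.domains
    val_size=0.1,
    seed=42,
)
data_module.prepare_data()  # downloads CDDB.tar.zip (Google Drive or S3) and extracts zip, then tar
data_module.setup()         # builds ImageFolder train/val/test splits for the chosen domain
train_dataset = data_module.train_data()
```

The same module is also reachable through `get_data_module(..., dataset_name="CDDB", ...)` added to `experiment_config.py` above, which is how the benchmarking scenarios construct it.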
From 33b1510a33bd198bcaa0068f6b934ef581c559c1 Mon Sep 17 00:00:00 2001
From: Prabhu Teja
Date: Mon, 2 Oct 2023 18:05:40 +0200
Subject: [PATCH 2/4] adding tests

---
 .../benchmark/test_experimentation_config.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py
index 92a84b49..8adbaec7 100644
--- a/test/renate/benchmark/test_experimentation_config.py
+++ b/test/renate/benchmark/test_experimentation_config.py
@@ -9,7 +9,11 @@
 from renate.benchmark import experiment_config
 from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule
-from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule
+from renate.benchmark.datasets.vision_datasets import (
+    CDDBDataModule,
+    CLEARDataModule,
+    TorchVisionDataModule,
+)
 from renate.benchmark.experiment_config import (
     data_module_fn,
     get_data_module,
@@ -95,6 +99,7 @@ def test_model_fn_fails_for_unknown_model():
             "label",
         ),
         ("MultiText", MultiTextDataModule, "distilbert-base-uncased", None, None),
+        ("CDDB", CDDBDataModule, None, None, None),
     ),
 )
 def test_get_data_module(
@@ -237,6 +242,13 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir):
             DataIncrementalScenario,
             2,
         ),
+        (
+            "DataIncrementalScenario",
+            "CDDB",
+            {"data_ids": ("biggan", "wild")},
+            DataIncrementalScenario,
+            2,
+        ),
     ),
     ids=[
         "class_incremental",
@@ -250,6 +262,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir):
         "wild_time_image_all_tasks",
         "domainnet_by data_id",
         "domainnet by groupings",
+        "cddb by dataid",
     ],
 )
 @pytest.mark.parametrize("val_size", (0, 0.5), ids=["no_val", "val"])
@@ -282,7 +295,7 @@ def test_data_module_fn(
     elif expected_scenario_class == DataIncrementalScenario:
         if "pretrained_model_name_or_path" in scenario_kwargs:
             assert scenario._data_module._tokenizer is not None
-        elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet"]:
+        elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet", "CDDB"]:
             assert scenario._data_module._tokenizer is None
     assert scenario._num_tasks == expected_num_tasks
@@ -296,6 +309,7 @@ def test_data_module_fn(
         ("CIFAR100", Compose, Normalize, "ResNet18CIFAR"),
         ("CLEAR10", Compose, Compose, "ResNet18"),
         ("DomainNet", Compose, Compose, "VisionTransformerB16"),
+        ("CDDB", Compose, Compose, None),
         ("hfd-rotten_tomatoes", type(None), type(None), "HuggingFaceTransformer"),
         ("fmow", Compose, Compose, "ResNet18"),
         ("yearbook", ToTensor, ToTensor, "ResNet18CIFAR"),

From bab1e4854ae0b0d7e3e143860f99fd60c00fd2e1 Mon Sep 17 00:00:00 2001
From: Prabhu Teja
Date: Wed, 4 Oct 2023 16:29:10 +0200
Subject: [PATCH 3/4] check filename with .zip

---
 src/renate/utils/file.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py
index 14f4defd..42b532b8 100644
--- a/src/renate/utils/file.py
+++ b/src/renate/utils/file.py
@@ -290,7 +290,7 @@ def delete_file_from_s3(bucket: str, object_name: str) -> None:
 def extract_file(dataset_name: str, data_path: Union[str, Path], file_name: str) -> None:
     """Extract a .zip or .tar file (by extension) into a folder named after the dataset."""
-    if file_name.endswith("zip"):
+    if file_name.endswith(".zip"):
         Extractor = ZipFile
     elif file_name.endswith(".tar"):
         Extractor = TarFile
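The reason PATCH 3/4 tightens the suffix check from `endswith("zip")` to `endswith(".zip")`: the un-dotted form also matches unrelated extensions, so a file ending in `.gzip`, for example, would have been routed to `ZipFile`. A small illustration (the file names other than CDDB.tar.zip are made up):

```python
# endswith("zip") is too permissive; endswith(".zip") only matches actual .zip archives.
print("archive.gzip".endswith("zip"))    # True  -> would wrongly select ZipFile
print("archive.gzip".endswith(".zip"))   # False -> no longer treated as a zip archive
print("CDDB.tar.zip".endswith(".zip"))   # True  -> ZipFile, as intended for the CDDB download
```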
From 7eca5715a36dd3b74fc9027c03ebd4c02086c3fc Mon Sep 17 00:00:00 2001
From: Prabhu Teja
Date: Wed, 4 Oct 2023 16:30:53 +0200
Subject: [PATCH 4/4] raising value error

---
 src/renate/utils/file.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py
index 42b532b8..98e595df 100644
--- a/src/renate/utils/file.py
+++ b/src/renate/utils/file.py
@@ -294,6 +294,8 @@ def extract_file(dataset_name: str, data_path: Union[str, Path], file_name: str)
         Extractor = ZipFile
     elif file_name.endswith(".tar"):
         Extractor = TarFile
+    else:
+        raise ValueError("Unknown compressed format extension. Only Zip/Tar supported.")
     with Extractor(os.path.join(data_path, dataset_name, file_name)) as f:
         f.extractall(os.path.join(data_path, dataset_name))
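With PATCH 4/4, `extract_file` rejects unsupported extensions explicitly instead of reaching the `with` statement with `Extractor` unbound (which would otherwise surface as an `UnboundLocalError`). A sketch of the resulting dispatch, mirroring how `CDDBDataModule.prepare_data` calls it; `./data` is a placeholder directory that must already contain the archives for the first two calls to succeed:

```python
from renate.utils.file import extract_file

# The CDDB download is a .zip that wraps a .tar, so it is extracted in two steps.
extract_file(dataset_name="", data_path="./data", file_name="CDDB.tar.zip")  # ZipFile branch
extract_file(dataset_name="", data_path="./data", file_name="CDDB.tar")      # TarFile branch

# Any other extension now fails fast:
# extract_file(dataset_name="", data_path="./data", file_name="CDDB.rar")
# -> ValueError: Unknown compressed format extension. Only Zip/Tar supported.
```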