From 7eae52652d74a5d51461a85bb5d41f2da5417f12 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 31 Jul 2020 20:48:36 +0200 Subject: [PATCH] add typehints for torchvision.datasets.omniglot --- torchvision/datasets/omniglot.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index dd861284884..167fc4e6d0e 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,6 +1,7 @@ from PIL import Image from os.path import join import os +from typing import Any, Callable, List, Optional, Tuple from .vision import VisionDataset from .utils import download_and_extract_archive, check_integrity, list_dir, list_files @@ -27,8 +28,14 @@ class Omniglot(VisionDataset): 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' } - def __init__(self, root, background=True, transform=None, target_transform=None, - download=False): + def __init__( + self, + root: str, + background: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform) self.background = background @@ -42,16 +49,16 @@ def __init__(self, root, background=True, transform=None, target_transform=None, self.target_folder = join(self.root, self._get_target_folder()) self._alphabets = list_dir(self.target_folder) - self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] - for a in self._alphabets], []) + self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] + for a in self._alphabets], []) self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] for idx, character in enumerate(self._characters)] - self._flat_character_images = sum(self._character_images, []) + self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) - def __len__(self): + def __len__(self) -> int: return len(self._flat_character_images) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index @@ -71,13 +78,13 @@ def __getitem__(self, index): return image, character_class - def _check_integrity(self): + def _check_integrity(self) -> bool: zip_filename = self._get_target_folder() if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): return False return True - def download(self): + def download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return @@ -87,5 +94,5 @@ def download(self): url = self.download_url_prefix + '/' + zip_filename download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename]) - def _get_target_folder(self): + def _get_target_folder(self) -> str: return 'images_background' if self.background else 'images_evaluation'