Skip to content

Commit

Permalink
Improve image extension handling, add methods to modify / get defaults.
Browse files Browse the repository at this point in the history
Fix #1335 fix #1274.
  • Loading branch information
rwightman committed Jul 7, 2022
1 parent 7d4b380 commit bfc0dcc
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 23 deletions.
5 changes: 3 additions & 2 deletions timm/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform
from .transforms_factory import create_transform
1 change: 1 addition & 0 deletions timm/data/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .parser_factory import create_parser
from .img_extensions import *
1 change: 0 additions & 1 deletion timm/data/parsers/constants.py

This file was deleted.

50 changes: 50 additions & 0 deletions timm/data/parsers/img_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from copy import deepcopy

__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']


# Ordered tuple of recognised image extensions (singleton, kept public for bwd compat use).
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Private set mirror of IMG_EXTENSIONS; _set_extensions rebinds both together to keep them in sync.
_IMG_EXTENSIONS_SET = {*IMG_EXTENSIONS}


def _set_extensions(extensions):
    """Replace the module-level extension registry with ``extensions``.

    Rebinds both ``IMG_EXTENSIONS`` (ordered tuple) and ``_IMG_EXTENSIONS_SET``
    (set) so the two views always hold the same members.

    Args:
        extensions: iterable of hashable extension values. Duplicates are
            dropped, keeping the first occurrence of each.
    """
    global IMG_EXTENSIONS
    global _IMG_EXTENSIONS_SET
    # dict.fromkeys de-dupes while keeping first-seen order.
    IMG_EXTENSIONS = tuple(dict.fromkeys(extensions))
    # Build the set from the tuple just created rather than from the argument:
    # a one-shot iterator would already be exhausted by the line above, which
    # would leave the set empty and out of sync with the tuple.
    _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)


def _valid_extension(x: str) -> bool:
    """Return True if ``x`` is a str of at least two characters starting with '.'.

    Always returns a real ``bool``; non-str and empty inputs yield ``False``
    instead of being passed back to the caller.
    """
    return isinstance(x, str) and len(x) >= 2 and x.startswith('.')


def is_img_extension(ext):
    """Report whether ``ext`` is one of the currently registered image extensions.

    This is an exact membership test against the private extension set; no
    case-folding or other normalisation of ``ext`` is applied here.
    """
    registered = _IMG_EXTENSIONS_SET
    return ext in registered


def get_img_extensions(as_set=False):
    """Return the currently registered image extensions.

    Args:
        as_set: if truthy, return the set view; otherwise return the ordered tuple.

    Returns:
        The result of ``deepcopy`` applied to the selected container.
    """
    if as_set:
        return deepcopy(_IMG_EXTENSIONS_SET)
    return deepcopy(IMG_EXTENSIONS)


def set_img_extensions(extensions):
    """Replace the registered image extensions with ``extensions``.

    Args:
        extensions: non-empty sized collection of extension strings, each
            starting with '.' and at least two characters long.

    Raises:
        AssertionError: if ``extensions`` is empty or any entry fails validation.
    """
    # Explicit raises instead of `assert` statements so validation still runs
    # under `python -O`; AssertionError is kept so callers see the same type.
    if not len(extensions):
        raise AssertionError('extensions must not be empty')
    for x in extensions:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    # Everything is validated before the registry is touched.
    _set_extensions(extensions)


def add_img_extensions(ext):
    """Register one or more additional image extensions.

    Args:
        ext: a single extension string, or a list / tuple / set of them. Each
            must start with '.' and be at least two characters long.

    Raises:
        AssertionError: if any entry fails validation.
    """
    # Anything that is not a list/tuple/set is treated as one extension.
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    # Explicit raise instead of an `assert` statement so validation still runs
    # under `python -O`; AssertionError is kept so callers see the same type.
    for x in ext:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    # New entries go after the existing ones; _set_extensions drops duplicates.
    extensions = IMG_EXTENSIONS + tuple(ext)
    _set_extensions(extensions)


def del_img_extensions(ext):
    """Remove one or more extensions from the registered image extensions.

    Args:
        ext: a single extension, or a list / tuple / set of them. Entries
            that are not currently registered have no effect.
    """
    # Anything that is not a list/tuple/set is treated as one extension.
    to_remove = ext if isinstance(ext, (list, tuple, set)) else (ext,)
    remaining = []
    for current in IMG_EXTENSIONS:
        if current not in to_remove:
            remaining.append(current)
    _set_extensions(tuple(remaining))
1 change: 0 additions & 1 deletion timm/data/parsers/parser_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_image_in_tar import ParserImageInTar


Expand Down
29 changes: 25 additions & 4 deletions timm/data/parsers/parser_image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,35 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
from typing import Dict, List, Optional, Set, Tuple, Union

from timm.utils.misc import natural_key

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from .img_extensions import get_img_extensions
from .parser import Parser


def find_images_and_targets(
folder: str,
types: Optional[Union[List, Tuple, Set]] = None,
class_to_idx: Optional[Dict] = None,
leaf_name_only: bool = True,
sort: bool = True
):
""" Walk folder recursively to discover images and map them to classes by folder names.
Args:
folder: root of folder to recursively search
types: types (file extensions) to search for in path
class_to_idx: specify mapping for class (folder name) to class index if set
leaf_name_only: use only leaf-name of folder walk for class names
sort: re-sort found images by name (for consistent ordering)
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
Returns:
A list of image and target tuples, class_to_idx mapping
"""
types = get_img_extensions(as_set=True) if not types else set(types)
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
Expand Down Expand Up @@ -51,7 +71,8 @@ def __init__(
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0:
raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(get_img_extensions())}')

def __getitem__(self, index):
path, target = self.samples[index]
Expand Down
29 changes: 18 additions & 11 deletions timm/data/parsers/parser_image_in_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
import tarfile
import pickle
import logging
import numpy as np
import tarfile
from glob import glob
from typing import List, Dict
from typing import List, Tuple, Dict, Set, Optional, Union

import numpy as np

from timm.utils.misc import natural_key

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS

from .img_extensions import get_img_extensions
from .parser import Parser

_logger = logging.getLogger(__name__)
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
Expand All @@ -39,7 +39,7 @@ def reset(self):
self.tf = None


def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
sample_count = 0
for i, ti in enumerate(tf):
if not ti.isfile():
Expand All @@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
return sample_count


def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
def extract_tarinfos(
root,
class_name_to_idx: Optional[Dict] = None,
cache_tarinfo: Optional[bool] = None,
extensions: Optional[Union[List, Tuple, Set]] = None,
sort: bool = True
):
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
root_is_tar = False
if os.path.isfile(root):
assert os.path.splitext(root)[-1].lower() == '.tar'
Expand Down Expand Up @@ -176,8 +183,8 @@ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
self.root,
class_name_to_idx=class_name_to_idx,
cache_tarinfo=cache_tarinfo,
extensions=IMG_EXTENSIONS)
cache_tarinfo=cache_tarinfo
)
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
if len(tarfiles) == 1 and tarfiles[0][0] is None:
self.root_is_tar = True
Expand Down
10 changes: 6 additions & 4 deletions timm/data/parsers/parser_image_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import os
import tarfile

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from timm.utils.misc import natural_key

from .class_map import load_class_map
from .img_extensions import get_img_extensions
from .parser import Parser


def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
extensions = get_img_extensions(as_set=True)
files = []
labels = []
for ti in tarfile.getmembers():
Expand All @@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname)
ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS:
if ext.lower() in extensions:
files.append(ti)
labels.append(label)
if class_to_idx is None:
Expand Down

0 comments on commit bfc0dcc

Please sign in to comment.