diff --git a/configs/_base_/datasets/stanford_cars_bs8_448.py b/configs/_base_/datasets/stanford_cars_bs8_448.py new file mode 100644 index 00000000000..636b2e14be4 --- /dev/null +++ b/configs/_base_/datasets/stanford_cars_bs8_448.py @@ -0,0 +1,46 @@ +# dataset settings +dataset_type = 'StanfordCars' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', size=512), + dict(type='RandomCrop', size=448), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', size=512), + dict(type='CenterCrop', crop_size=448), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] + +data_root = 'data/stanfordcars' +data = dict( + samples_per_gpu=8, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_prefix=data_root, + test_mode=False, + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_prefix=data_root, + test_mode=True, + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_prefix=data_root, + test_mode=True, + pipeline=test_pipeline)) + +evaluation = dict( + interval=1, metric='accuracy', + save_best='auto') # save the checkpoint with highest accuracy diff --git a/configs/_base_/schedules/stanford_cars_bs8.py b/configs/_base_/schedules/stanford_cars_bs8.py new file mode 100644 index 00000000000..dee252ec767 --- /dev/null +++ b/configs/_base_/schedules/stanford_cars_bs8.py @@ -0,0 +1,7 @@ +# optimizer +optimizer = dict( + type='SGD', lr=0.003, momentum=0.9, weight_decay=0.0005, nesterov=True) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[40, 70, 90]) +runner = dict(type='EpochBasedRunner', max_epochs=100) diff --git a/configs/resnet/README.md b/configs/resnet/README.md index f1d32effde7..d32fcd64e03 100644 --- a/configs/resnet/README.md +++ b/configs/resnet/README.md @@ -72,6 +72,12 @@ The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don' | :-------: | :--------------------------------------------------: | :--------: | :-------: | :------: | :-------: | :------------------------------------------------: | :---------------------------------------------------: | | ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 448x448 | 23.92 | 16.48 | 88.45 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cub.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.log.json) | +### Stanford-Cars + +| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download | +| :-------: | :--------------------------------------------------: | :--------: | :-------: | :------: | :-------: | :------------------------------------------------: | :---------------------------------------------------: | +| ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 448x448 | 23.92 | 16.48 | 92.82 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cars.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.log.json) | + ## Citation ``` diff --git a/configs/resnet/metafile.yml b/configs/resnet/metafile.yml index 29aa84df37b..4be4bf9bf48 100644 --- a/configs/resnet/metafile.yml +++ b/configs/resnet/metafile.yml @@ -350,3 +350,16 @@ Models: Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth Config: configs/resnet/resnet50_8xb8_cub.py + - Name: resnet50_8xb8_cars + Metadata: + FLOPs: 16480000000 + Parameters: 23920000 + In Collection: ResNet + Results: + - Dataset: StanfordCars + Metrics: + Top 1 Accuracy: 92.82 + Task: Image Classification + Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.pth + Config: configs/resnet/resnet50_8xb8_cars.py diff --git a/configs/resnet/resnet50_8xb8_cars.py b/configs/resnet/resnet50_8xb8_cars.py new file mode 100644 index 00000000000..2d2db45d08a --- /dev/null +++ b/configs/resnet/resnet50_8xb8_cars.py @@ -0,0 +1,19 @@ +_base_ = [ + '../_base_/models/resnet50.py', + '../_base_/datasets/stanford_cars_bs8_448.py', + '../_base_/schedules/stanford_cars_bs8.py', '../_base_/default_runtime.py' +] + +# use pre-train weight converted from https://github.com/Alibaba-MIIL/ImageNet21K # noqa +checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth' # noqa + +model = dict( + type='ImageClassifier', + backbone=dict( + init_cfg=dict( + type='Pretrained', checkpoint=checkpoint, prefix='backbone')), + head=dict(num_classes=196, )) + +log_config = dict(interval=50) +checkpoint_config = dict( + interval=1, max_keep_ckpts=3) # save last three checkpoints diff --git a/docs/en/api/datasets.rst b/docs/en/api/datasets.rst index 585f5586675..640ce1ad7d2 100644 --- a/docs/en/api/datasets.rst +++ b/docs/en/api/datasets.rst @@ -39,6 +39,11 @@ VOC .. autoclass:: VOC +StanfordCars Cars +----------------- + +.. autoclass:: StanfordCars + Base classes ------------ diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py index c71dd50a201..095077e2321 100644 --- a/mmcls/datasets/__init__.py +++ b/mmcls/datasets/__init__.py @@ -12,6 +12,7 @@ from .mnist import MNIST, FashionMNIST from .multi_label import MultiLabelDataset from .samplers import DistributedSampler, RepeatAugSampler +from .stanford_cars import StanfordCars from .voc import VOC __all__ = [ @@ -19,5 +20,6 @@ 'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset', 'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS', - 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset' + 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', + 'CustomDataset', 'StanfordCars' ] diff --git a/mmcls/datasets/stanford_cars.py b/mmcls/datasets/stanford_cars.py new file mode 100644 index 00000000000..df1f95126f6 --- /dev/null +++ b/mmcls/datasets/stanford_cars.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Optional + +import numpy as np + +from .base_dataset import BaseDataset +from .builder import DATASETS + + +@DATASETS.register_module() +class StanfordCars(BaseDataset): + """`Stanford Cars`_ Dataset. + + After downloading and decompression, the dataset + directory structure is as follows. + + Stanford Cars dataset directory:: + + Stanford Cars + ├── cars_train + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + ├── cars_test + │ ├── 00001.jpg + │ ├── 00002.jpg + │ └── ... + └── devkit + ├── cars_meta.mat + ├── cars_train_annos.mat + ├── cars_test_annos.mat + ├── cars_test_annoswithlabels.mat + ├── eval_train.m + └── train_perfect_preds.txt + + .. _Stanford Cars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html + + Args: + data_prefix (str): the prefix of data path + test_mode (bool): ``test_mode=True`` means in test phase. It determines + to use the training set or test set. + ann_file (str, optional): The annotation file. If is string, read + samples paths from the ann_file. If is None, read samples path + from cars_{train|test}_annos.mat file. Defaults to None. + """ # noqa: E501 + + CLASSES = [ + 'AM General Hummer SUV 2000', 'Acura RL Sedan 2012', + 'Acura TL Sedan 2012', 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012', + 'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012', + 'Aston Martin V8 Vantage Convertible 2012', + 'Aston Martin V8 Vantage Coupe 2012', + 'Aston Martin Virage Convertible 2012', + 'Aston Martin Virage Coupe 2012', 'Audi RS 4 Convertible 2008', + 'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', 'Audi R8 Coupe 2012', + 'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', 'Audi 100 Wagon 1994', + 'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011', + 'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012', + 'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012', + 'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012', + 'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012', + 'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007', + 'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012', + 'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012', + 'BMW Z4 Convertible 2012', + 'Bentley Continental Supersports Conv. Convertible 2012', + 'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011', + 'Bentley Continental GT Coupe 2012', + 'Bentley Continental GT Coupe 2007', + 'Bentley Continental Flying Spur Sedan 2007', + 'Bugatti Veyron 16.4 Convertible 2009', + 'Bugatti Veyron 16.4 Coupe 2009', 'Buick Regal GS 2012', + 'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012', + 'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012', + 'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007', + 'Chevrolet Silverado 1500 Hybrid Crew Cab 2012', + 'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012', + 'Chevrolet Corvette Ron Fellows Edition Z06 2007', + 'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012', + 'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007', + 'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012', + 'Chevrolet Express Cargo Van 2007', + 'Chevrolet Avalanche Crew Cab 2012', 'Chevrolet Cobalt SS 2010', + 'Chevrolet Malibu Hybrid Sedan 2010', 'Chevrolet TrailBlazer SS 2009', + 'Chevrolet Silverado 2500HD Regular Cab 2012', + 'Chevrolet Silverado 1500 Classic Extended Cab 2007', + 'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007', + 'Chevrolet Malibu Sedan 2007', + 'Chevrolet Silverado 1500 Extended Cab 2012', + 'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009', + 'Chrysler Sebring Convertible 2010', + 'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010', + 'Chrysler Crossfire Convertible 2008', + 'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002', + 'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007', + 'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010', + 'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009', + 'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010', + 'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008', + 'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012', + 'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012', + 'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998', + 'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012', + 'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012', + 'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012', + 'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012', + 'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007', + 'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012', + 'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006', + 'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007', + 'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012', + 'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012', + 'GMC Savana Van 2012', 'GMC Yukon Hybrid SUV 2012', + 'GMC Acadia SUV 2012', 'GMC Canyon Extended Cab 2012', + 'Geo Metro Convertible 1993', 'HUMMER H3T Crew Cab 2010', + 'HUMMER H2 SUT Crew Cab 2009', 'Honda Odyssey Minivan 2012', + 'Honda Odyssey Minivan 2007', 'Honda Accord Coupe 2012', + 'Honda Accord Sedan 2012', 'Hyundai Veloster Hatchback 2012', + 'Hyundai Santa Fe SUV 2012', 'Hyundai Tucson SUV 2012', + 'Hyundai Veracruz SUV 2012', 'Hyundai Sonata Hybrid Sedan 2012', + 'Hyundai Elantra Sedan 2007', 'Hyundai Accent Sedan 2012', + 'Hyundai Genesis Sedan 2012', 'Hyundai Sonata Sedan 2012', + 'Hyundai Elantra Touring Hatchback 2012', 'Hyundai Azera Sedan 2012', + 'Infiniti G Coupe IPL 2012', 'Infiniti QX56 SUV 2011', + 'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012', + 'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012', + 'Jeep Liberty SUV 2012', 'Jeep Grand Cherokee SUV 2012', + 'Jeep Compass SUV 2012', 'Lamborghini Reventon Coupe 2008', + 'Lamborghini Aventador Coupe 2012', + 'Lamborghini Gallardo LP 570-4 Superleggera 2012', + 'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012', + 'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011', + 'MINI Cooper Roadster Convertible 2012', + 'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011', + 'McLaren MP4-12C Coupe 2012', + 'Mercedes-Benz 300-Class Convertible 1993', + 'Mercedes-Benz C-Class Sedan 2012', + 'Mercedes-Benz SL-Class Coupe 2009', + 'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012', + 'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012', + 'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012', + 'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998', + 'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012', + 'Ram C/V Cargo Van Minivan 2012', + 'Rolls-Royce Phantom Drophead Coupe Convertible 2012', + 'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012', + 'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009', + 'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007', + 'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012', + 'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012', + 'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012', + 'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012', + 'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991', + 'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012', + 'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007', + 'smart fortwo Convertible 2012' + ] + + def __init__(self, + data_prefix: str, + test_mode: bool, + ann_file: Optional[str] = None, + **kwargs): + if test_mode: + if ann_file is not None: + self.test_ann_file = ann_file + else: + self.test_ann_file = osp.join( + data_prefix, 'devkit/cars_test_annos_withlabels.mat') + data_prefix = osp.join(data_prefix, 'cars_test') + else: + if ann_file is not None: + self.train_ann_file = ann_file + else: + self.train_ann_file = osp.join(data_prefix, + 'devkit/cars_train_annos.mat') + data_prefix = osp.join(data_prefix, 'cars_train') + super(StanfordCars, self).__init__( + ann_file=ann_file, + data_prefix=data_prefix, + test_mode=test_mode, + **kwargs) + + def load_annotations(self): + try: + import scipy.io as sio + except ImportError: + raise ImportError( + 'please run `pip install scipy` to install package `scipy`.') + + data_infos = [] + if self.test_mode: + data = sio.loadmat(self.test_ann_file) + else: + data = sio.loadmat(self.train_ann_file) + for img in data['annotations'][0]: + info = {'img_prefix': self.data_prefix} + # The organization of each record is as follows, + # 0: bbox_x1 of each image + # 1: bbox_y1 of each image + # 2: bbox_x2 of each image + # 3: bbox_y2 of each image + # 4: class_id, start from 0, so + # here we need to '- 1' to let them start from 0 + # 5: file name of each image + info['img_info'] = {'filename': img[5][0]} + info['gt_label'] = np.array(img[4][0][0] - 1, dtype=np.int64) + data_infos.append(info) + return data_infos diff --git a/requirements/optional.txt b/requirements/optional.txt index 8d449aae5ee..cc0228041b1 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -2,3 +2,4 @@ albumentations>=0.3.2 --no-binary qudida,albumentations colorama requests rich +scipy diff --git a/tests/test_data/test_datasets/test_common.py b/tests/test_data/test_datasets/test_common.py index b6bfe3bd341..5ec38184763 100644 --- a/tests/test_data/test_datasets/test_common.py +++ b/tests/test_data/test_datasets/test_common.py @@ -761,3 +761,151 @@ def test_load_annotations(self): @classmethod def tearDownClass(cls): cls.tmpdir.cleanup() + + +class TestStanfordCars(TestBaseDataset): + DATASET_TYPE = 'StanfordCars' + + def test_initialize(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + with patch.object(dataset_class, 'load_annotations'): + # Test with test_mode=False, ann_file is None + cfg = {**self.DEFAULT_ARGS, 'test_mode': False, 'ann_file': None} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertFalse(dataset.test_mode) + self.assertIsNone(dataset.ann_file) + self.assertIsNotNone(dataset.train_ann_file) + + # Test with test_mode=False, ann_file is not None + cfg = { + **self.DEFAULT_ARGS, 'test_mode': False, + 'ann_file': 'train_ann_file.mat' + } + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertFalse(dataset.test_mode) + self.assertIsNotNone(dataset.ann_file) + self.assertEqual(dataset.ann_file, 'train_ann_file.mat') + self.assertIsNotNone(dataset.train_ann_file) + + # Test with test_mode=True, ann_file is None + cfg = {**self.DEFAULT_ARGS, 'test_mode': True, 'ann_file': None} + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertTrue(dataset.test_mode) + self.assertIsNone(dataset.ann_file) + self.assertIsNotNone(dataset.test_ann_file) + + # Test with test_mode=True, ann_file is not None + cfg = { + **self.DEFAULT_ARGS, 'test_mode': True, + 'ann_file': 'test_ann_file.mat' + } + dataset = dataset_class(**cfg) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + self.assertTrue(dataset.test_mode) + self.assertIsNotNone(dataset.ann_file) + self.assertEqual(dataset.ann_file, 'test_ann_file.mat') + self.assertIsNotNone(dataset.test_ann_file) + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + tmpdir = tempfile.TemporaryDirectory() + cls.tmpdir = tmpdir + cls.data_prefix = tmpdir.name + cls.ann_file = None + devkit = osp.join(cls.data_prefix, 'devkit') + if not osp.exists(devkit): + os.mkdir(devkit) + cls.train_ann_file = osp.join(devkit, 'cars_train_annos.mat') + cls.test_ann_file = osp.join(devkit, 'cars_test_annos_withlabels.mat') + cls.DEFAULT_ARGS = dict( + data_prefix=cls.data_prefix, pipeline=[], test_mode=False) + + try: + import scipy.io as sio + except ImportError: + raise ImportError( + 'please run `pip install scipy` to install package `scipy`.') + + sio.savemat( + cls.train_ann_file, { + 'annotations': [( + (np.array([1]), np.array([10]), np.array( + [20]), np.array([50]), 15, np.array(['001.jpg'])), + (np.array([2]), np.array([15]), np.array( + [240]), np.array([250]), 15, np.array(['002.jpg'])), + (np.array([89]), np.array([150]), np.array( + [278]), np.array([388]), 150, np.array(['012.jpg'])), + )] + }) + + sio.savemat( + cls.test_ann_file, { + 'annotations': + [((np.array([89]), np.array([150]), np.array( + [278]), np.array([388]), 150, np.array(['025.jpg'])), + (np.array([155]), np.array([10]), np.array( + [200]), np.array([233]), 0, np.array(['111.jpg'])), + (np.array([25]), np.array([115]), np.array( + [240]), np.array([360]), 15, np.array(['265.jpg'])))] + }) + + def test_load_annotations(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + + # Test with test_mode=False and ann_file=None + dataset = dataset_class(**self.DEFAULT_ARGS) + self.assertEqual(len(dataset), 3) + self.assertEqual(dataset.CLASSES, dataset_class.CLASSES) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], + osp.join(self.data_prefix, 'cars_train')) + np.testing.assert_equal(data_info['img_info'], {'filename': '001.jpg'}) + np.testing.assert_equal(data_info['gt_label'], 15 - 1) + + # Test with test_mode=True and ann_file=None + cfg = {**self.DEFAULT_ARGS, 'test_mode': True} + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], + osp.join(self.data_prefix, 'cars_test')) + np.testing.assert_equal(data_info['img_info'], {'filename': '025.jpg'}) + np.testing.assert_equal(data_info['gt_label'], 150 - 1) + + # Test with test_mode=False, ann_file is not None + cfg = { + **self.DEFAULT_ARGS, 'test_mode': False, + 'ann_file': self.train_ann_file + } + dataset = dataset_class(**cfg) + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], + osp.join(self.data_prefix, 'cars_train')) + np.testing.assert_equal(data_info['img_info'], {'filename': '001.jpg'}) + np.testing.assert_equal(data_info['gt_label'], 15 - 1) + + # Test with test_mode=True, ann_file is not None + cfg = { + **self.DEFAULT_ARGS, 'test_mode': True, + 'ann_file': self.test_ann_file + } + dataset = dataset_class(**cfg) + self.assertEqual(len(dataset), 3) + + data_info = dataset[0] + np.testing.assert_equal(data_info['img_prefix'], + osp.join(self.data_prefix, 'cars_test')) + np.testing.assert_equal(data_info['img_info'], {'filename': '025.jpg'}) + np.testing.assert_equal(data_info['gt_label'], 150 - 1) + + @classmethod + def tearDownClass(cls): + cls.tmpdir.cleanup()