From edfd48271856d29f947cd4306bfb197a4bd7fcb0 Mon Sep 17 00:00:00 2001 From: Cathy0908 <30484308+Cathy0908@users.noreply.github.com> Date: Wed, 8 Feb 2023 14:09:14 +0800 Subject: [PATCH] add mobilenet itag config (#276) * add mobilenet itag config --- configs/classification/itag/mobilenetv2.py | 88 +++++++++++++ easycv/apis/export.py | 2 + .../classification/data_sources/image_list.py | 120 ++++++------------ easycv/predictors/classifier.py | 8 +- .../data_sources/test_cls_itag_datasource.py | 92 ++++++++++++++ tests/ut_config.py | 4 + tools/train.py | 4 +- 7 files changed, 233 insertions(+), 85 deletions(-) create mode 100644 configs/classification/itag/mobilenetv2.py create mode 100644 tests/datasets/classification/data_sources/test_cls_itag_datasource.py diff --git a/configs/classification/itag/mobilenetv2.py b/configs/classification/itag/mobilenetv2.py new file mode 100644 index 00000000..df3dc581 --- /dev/null +++ b/configs/classification/itag/mobilenetv2.py @@ -0,0 +1,88 @@ +_base_ = '../imagenet/common/classification_base.py' + +# oss_io_config = dict(ak_id='', # your oss ak id +# ak_secret='', # your oss ak secret +# hosts='', # your oss hosts +# buckets=[]) # your oss bucket name + +# Ensure the CLASSES definition is in one line, for adapt to its replacement by user_config_params. +# yapf:disable +CLASSES = ['label1', 'label2', 'label3'] # replace with your true lables of itag manifest file +num_classes = 3 +# model settings +model = dict( + type='Classification', + backbone=dict(type='MobileNetV2'), + head=dict( + type='ClsHead', + with_avg_pool=True, + in_channels=1280, + loss_config=dict( + type='CrossEntropyLossWithLabelSmooth', + label_smooth=0, + ), + num_classes=num_classes)) + +train_itag_file = '/your/itag/train/file.manifest' # or oss://your_bucket/data/train.manifest +test_itag_file = '/your/itag/test/file.manifest' # oss://your_bucket/data/test.manifest + +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) +data_source_type = 'ClsSourceItag' +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + type=data_source_type, + list_file=train_itag_file, + class_list=CLASSES, + ), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + type=data_source_type, + list_file=test_itag_file, + class_list=CLASSES), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[dict(type='ClsEvaluator', topk=(1, ))], + ) +] + +# optimizer +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) + +# learning policy +lr_config = dict(policy='step', step=[30, 60, 90]) +checkpoint_config = dict(interval=5) + +# runtime settings +total_epochs = 100 +checkpoint_sync_export = True +export = dict(export_neck=True) diff --git a/easycv/apis/export.py b/easycv/apis/export.py index 9075c61a..f4529d2b 100644 --- a/easycv/apis/export.py +++ b/easycv/apis/export.py @@ -116,6 +116,8 @@ def _export_cls(model, cfg, filename): class_list = io.open(label_map_path).readlines() elif hasattr(cfg, 'class_list'): class_list = cfg.class_list + elif hasattr(cfg, 'CLASSES'): + class_list = cfg.CLASSES model_config = dict( type='Classification', diff --git a/easycv/datasets/classification/data_sources/image_list.py b/easycv/datasets/classification/data_sources/image_list.py index bede4761..2c728145 100644 --- a/easycv/datasets/classification/data_sources/image_list.py +++ b/easycv/datasets/classification/data_sources/image_list.py @@ -8,7 +8,7 @@ from easycv.datasets.registry import DATASOURCES from easycv.file import io from easycv.file.image import load_image -from easycv.framework.errors import TypeError +from easycv.framework.errors import TypeError, ValueError from easycv.utils.dist_utils import dist_zero_exec from .utils import split_listfile_byrank @@ -28,7 +28,6 @@ class ClsSourceImageList(object): If split, data list will be split to each rank. split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path. - max_try: int, max try numbers of reading image """ def __init__(self, @@ -37,13 +36,9 @@ def __init__(self, delimeter=' ', split_huge_listfile_byrank=False, split_label_balance=False, - cache_path='data/', - max_try=20): + cache_path='data/'): ImageFile.LOAD_TRUNCATED_IMAGES = True - - self.max_try = max_try - # DistributedMPSampler need this attr self.has_labels = True @@ -124,77 +119,39 @@ class ClsSourceItag(ClsSourceImageList): list_file : str / list(str), str means a input image list file path, this file contains records as `image_path label` in list_file list(str) means multi image list, each one contains some records as `image_path label` - root: str / list(str), root path for image_path, each list_file will need a root, - if len(root) < len(list_file), we will use root[-1] to fill root list. - delimeter: str, delimeter of each line in the `list_file` - split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list. - If split, data list will be split to each rank. - split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance - cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path. - max_try: int, max try numbers of reading image """ - def __init__(self, - list_file, - root='', - delimeter=' ', - split_huge_listfile_byrank=False, - split_label_balance=False, - cache_path='data/', - max_try=20): - + def __init__(self, list_file, root='', class_list=None): + assert root is None or len( + root) < 1, 'The "root" param is not used and will be removed soon!' ImageFile.LOAD_TRUNCATED_IMAGES = True - - self.max_try = max_try - # DistributedMPSampler need this attr self.has_labels = True - - if isinstance(list_file, str): - assert isinstance(root, str), 'list_file is str, root must be str' - list_file = [list_file] - root = [root] + self.class_list = class_list + if self.class_list is None: + logging.warning( + 'It is recommended to specify the ``class_list`` parameter!') + self._auto_collect_labels = True + self.label_dict = {} else: - assert isinstance(list_file, list), \ - 'list_file should be str or list(str)' - root = [root] if isinstance(root, str) else root - if not isinstance(root, list): - raise TypeError('root must be str or list(str), but get %s' % - type(root)) - - if len(root) < len(list_file): - logging.warning( - 'len(root) < len(list_file), fill root with root last!') - root = root + [root[-1]] * (len(list_file) - len(root)) - - # TODO: support return list, donot save split file - # TODO: support loading list_file that have already been split - if split_huge_listfile_byrank: - with dist_zero_exec(): - list_file = split_listfile_byrank( - list_file=list_file, - label_balance=split_label_balance, - save_path=cache_path) - - self.fns = [] - self.labels = [] - label_dict = dict() - for l, r in zip(list_file, root): - fns, labels, label_dict = self.parse_list_file(l, label_dict) - self.fns += fns - self.labels += labels + self.label_dict = dict( + zip(self.class_list, range(len(self.class_list)))) + self._auto_collect_labels = False + self.fns, self.labels, self.label_dict = self.parse_list_file( + list_file, self.label_dict, self._auto_collect_labels) @staticmethod - def parse_list_file(list_file, label_dict): - with open(list_file, 'r', encoding='utf-8') as f: - data = f.readlines() + def parse_list_file(list_file, label_dict, auto_collect_labels=True): + with io.open(list_file, 'r') as f: + rows = f.read().splitlines() fns = [] - labels = [] - for i in range(len(data)): - data_i = json.loads(data[i]) + labels_id = [] + + for row_str in rows: + data_i = json.loads(row_str.strip()) img_path = data_i['data']['source'] - label = [] + label_id = [] priority = 2 for k in data_i.keys(): @@ -206,26 +163,27 @@ def parse_list_file(list_file, label_dict): for k, v in data_i.items(): if 'label' in k: - label = [] + label_id = [] result_list = v['results'] for j in range(len(result_list)): - anno_list = result_list[j]['data'] - if 'labels' in anno_list: - if anno_list['labels']['单选'] not in label_dict: - label_dict[anno_list['labels']['单选']] = len( - label_dict) - label.append(label_dict[anno_list['labels']['单选']]) - else: - if anno_list not in label_dict: - label_dict[anno_list] = len(label_dict) - label.append(label_dict[anno_list]) + label = result_list[j]['data'] + if 'labels' in label: + label = label['labels']['单选'] + if label not in label_dict: + if auto_collect_labels: + label_dict[label] = len(label_dict) + else: + raise ValueError( + f'Not find label "{label}" in label dict: {label_dict}' + ) + label_id.append(label_dict[label]) if 'verify' in k: break elif 'check' in k and priority == 1: break fns.append(img_path) - labels.append( - label[0]) if len(label) == 1 else labels.append(label) + labels_id.append(label_id[0]) if len( + label_id) == 1 else labels_id.append(label_id) - return fns, labels, label_dict + return fns, labels_id, label_dict diff --git a/easycv/predictors/classifier.py b/easycv/predictors/classifier.py index eedcdb5e..facd9feb 100644 --- a/easycv/predictors/classifier.py +++ b/easycv/predictors/classifier.py @@ -210,13 +210,17 @@ def __init__(self, model_config: config string for model to init, in json format """ self.predictor = Predictor(model_path) - if 'class_list' not in self.predictor.cfg and label_map_path is None: + if 'class_list' not in self.predictor.cfg and \ + 'CLASSES' not in self.predictor.cfg and \ + label_map_path is None: raise ValueError( - "label_map_path need to be set, when ckpt doesn't contain class_list" + "'label_map_path' need to be set, when ckpt doesn't contain key 'class_list' and 'CLASSES'!" ) if label_map_path is None: class_list = self.predictor.cfg.get('class_list', []) + if len(class_list) < 1: + class_list = self.predictor.cfg.get('CLASSES', []) self.label_map = [i.strip() for i in class_list] else: class_list = open(label_map_path).readlines() diff --git a/tests/datasets/classification/data_sources/test_cls_itag_datasource.py b/tests/datasets/classification/data_sources/test_cls_itag_datasource.py new file mode 100644 index 00000000..e3b9d25f --- /dev/null +++ b/tests/datasets/classification/data_sources/test_cls_itag_datasource.py @@ -0,0 +1,92 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from tests.ut_config import CLS_DATA_ITAG_OSS + +from easycv.datasets.builder import build_datasource +from easycv.framework.errors import ValueError + + +class ClsSourceImageListTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_default(self): + from easycv.file import io + io.access_oss() + + cfg = dict(type='ClsSourceItag', list_file=CLS_DATA_ITAG_OSS) + data_source = build_datasource(cfg) + + index_list = list(range(5)) + for idx in index_list: + results = data_source[idx] + img = results['img'] + label = results['gt_labels'] + self.assertEqual(img.mode, 'RGB') + self.assertIn(label, list(range(3))) + img.close() + + self.assertEqual(len(data_source), 11) + self.assertDictEqual(data_source.label_dict, { + 'ng': 0, + 'ok': 1, + '中文': 2 + }) + + def test_with_class_list(self): + from easycv.file import io + io.access_oss() + + cfg = dict( + type='ClsSourceItag', + class_list=['中文', 'ng', 'ok'], + list_file=CLS_DATA_ITAG_OSS) + data_source = build_datasource(cfg) + + index_list = list(range(5)) + for idx in index_list: + results = data_source[idx] + img = results['img'] + label = results['gt_labels'] + self.assertEqual(img.mode, 'RGB') + self.assertIn(label, list(range(3))) + img.close() + + self.assertEqual(len(data_source), 11) + self.assertDictEqual(data_source.label_dict, { + '中文': 0, + 'ng': 1, + 'ok': 2 + }) + + def test_with_fault_class_list(self): + from easycv.file import io + io.access_oss() + + with self.assertRaises(ValueError) as cm: + cfg = dict( + type='ClsSourceItag', + class_list=['error', 'ng', 'ok'], + list_file=CLS_DATA_ITAG_OSS) + + data_source = build_datasource(cfg) + index_list = list(range(5)) + for idx in index_list: + results = data_source[idx] + img = results['img'] + label = results['gt_labels'] + self.assertEqual(img.mode, 'RGB') + self.assertIn(label, list(range(3))) + img.close() + + exception = cm.exception + self.assertEqual( + exception.message, + "Not find label \"中文\" in label dict: {'error': 0, 'ng': 1, 'ok': 2}" + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ut_config.py b/tests/ut_config.py index 350792a1..13dfa15f 100644 --- a/tests/ut_config.py +++ b/tests/ut_config.py @@ -68,6 +68,10 @@ TMP_DIR_OSS = os.path.join(BASE_OSS_PATH, 'tmp') TMP_DIR_LOCAL = os.path.join(BASE_LOCAL_PATH, 'tmp') +CLS_DATA_ITAG_OSS = os.path.join( + BASE_OSS_PATH, + 'local_backup/easycv_nfs/data/classification/cls_itagtest/cls_itagtest.manifest' +) CLS_DATA_NPY_LOCAL = os.path.join(BASE_LOCAL_PATH, 'data/classification/npy/') SMALL_IMAGENET_RAW_LOCAL = os.path.join( BASE_LOCAL_PATH, 'data/classification/small_imagenet_raw') diff --git a/tools/train.py b/tools/train.py index 2b44c7c4..55d0b438 100644 --- a/tools/train.py +++ b/tools/train.py @@ -162,9 +162,9 @@ def main(): else: cfg.oss_work_dir = None - if args.resume_from is not None: + if args.resume_from is not None and len(args.resume_from) > 0: cfg.resume_from = args.resume_from - if args.load_from is not None: + if args.load_from is not None and len(args.load_from) > 0: cfg.load_from = args.load_from # dynamic adapt mmdet models