Skip to content

Commit

Permalink
add mobilenet itag config (#276)
Browse files Browse the repository at this point in the history
* add mobilenet itag config
  • Loading branch information
Cathy0908 authored Feb 8, 2023
1 parent dd4e6bd commit edfd482
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 85 deletions.
88 changes: 88 additions & 0 deletions configs/classification/itag/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Config for training a MobileNetV2 image classifier on data annotated with
# iTAG (PAI manifest-file format). Training defaults are inherited from the
# common ImageNet classification base config.
_base_ = '../imagenet/common/classification_base.py'

# Uncomment and fill in the fields below when reading data from OSS:
# oss_io_config = dict(ak_id='', # your oss ak id
#                      ak_secret='', # your oss ak secret
#                      hosts='', # your oss hosts
#                      buckets=[]) # your oss bucket name

# Keep the CLASSES definition on a single line so it can be replaced
# wholesale via user_config_params.
# yapf:disable
CLASSES = ['label1', 'label2', 'label3']  # replace with the true labels of your itag manifest file
num_classes = 3
# model settings
model = dict(
    type='Classification',
    backbone=dict(type='MobileNetV2'),
    head=dict(
        type='ClsHead',
        with_avg_pool=True,
        in_channels=1280,  # output channel width of the MobileNetV2 backbone
        loss_config=dict(
            type='CrossEntropyLossWithLabelSmooth',
            label_smooth=0,  # 0 disables label smoothing
        ),
        num_classes=num_classes))

# iTAG manifest files; local paths or oss:// URLs are both accepted.
train_itag_file = '/your/itag/train/file.manifest'  # or oss://your_bucket/data/train.manifest
test_itag_file = '/your/itag/test/file.manifest'  # oss://your_bucket/data/test.manifest

# Test images are resized to image_size1 and then center-cropped to
# image_size2 (standard ImageNet 256 -> 224 ratio).
image_size2 = 224
image_size1 = int((256 / 224) * image_size2)
data_source_type = 'ClsSourceItag'
dataset_type = 'ClsDataset'
# ImageNet mean/std normalization.
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
    dict(type='RandomResizedCrop', size=image_size2),
    dict(type='RandomHorizontalFlip'),
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]
test_pipeline = [
    dict(type='Resize', size=image_size1),
    dict(type='CenterCrop', size=image_size2),
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]

data = dict(
    imgs_per_gpu=32,  # total 256 (assumes 8 GPUs)
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_source=dict(
            type=data_source_type,
            list_file=train_itag_file,
            class_list=CLASSES,
        ),
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_source=dict(
            type=data_source_type,
            list_file=test_itag_file,
            class_list=CLASSES),
        pipeline=test_pipeline))

# Evaluate every epoch; gather results across GPUs.
eval_config = dict(initial=False, interval=1, gpu_collect=True)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        evaluators=[dict(type='ClsEvaluator', topk=(1, ))],
    )
]

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)

# learning policy: step decay at epochs 30/60/90
lr_config = dict(policy='step', step=[30, 60, 90])
checkpoint_config = dict(interval=5)  # save a checkpoint every 5 epochs

# runtime settings
total_epochs = 100
# NOTE(review): presumably syncs the exported model alongside checkpoint
# saving — confirm against the exporter implementation.
checkpoint_sync_export = True
export = dict(export_neck=True)
2 changes: 2 additions & 0 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def _export_cls(model, cfg, filename):
class_list = io.open(label_map_path).readlines()
elif hasattr(cfg, 'class_list'):
class_list = cfg.class_list
elif hasattr(cfg, 'CLASSES'):
class_list = cfg.CLASSES

model_config = dict(
type='Classification',
Expand Down
120 changes: 39 additions & 81 deletions easycv/datasets/classification/data_sources/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image
from easycv.framework.errors import TypeError
from easycv.framework.errors import TypeError, ValueError
from easycv.utils.dist_utils import dist_zero_exec
from .utils import split_listfile_byrank

Expand All @@ -28,7 +28,6 @@ class ClsSourceImageList(object):
If split, data list will be split to each rank.
split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance
cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path.
max_try: int, max try numbers of reading image
"""

def __init__(self,
Expand All @@ -37,13 +36,9 @@ def __init__(self,
delimeter=' ',
split_huge_listfile_byrank=False,
split_label_balance=False,
cache_path='data/',
max_try=20):
cache_path='data/'):

ImageFile.LOAD_TRUNCATED_IMAGES = True

self.max_try = max_try

# DistributedMPSampler need this attr
self.has_labels = True

Expand Down Expand Up @@ -124,77 +119,39 @@ class ClsSourceItag(ClsSourceImageList):
list_file : str / list(str), str means a input image list file path,
this file contains records as `image_path label` in list_file
list(str) means multi image list, each one contains some records as `image_path label`
root: str / list(str), root path for image_path, each list_file will need a root,
if len(root) < len(list_file), we will use root[-1] to fill root list.
delimeter: str, delimeter of each line in the `list_file`
split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list.
If split, data list will be split to each rank.
split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance
cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path.
max_try: int, max try numbers of reading image
"""

def __init__(self,
list_file,
root='',
delimeter=' ',
split_huge_listfile_byrank=False,
split_label_balance=False,
cache_path='data/',
max_try=20):

def __init__(self, list_file, root='', class_list=None):
assert root is None or len(
root) < 1, 'The "root" param is not used and will be removed soon!'
ImageFile.LOAD_TRUNCATED_IMAGES = True

self.max_try = max_try

# DistributedMPSampler need this attr
self.has_labels = True

if isinstance(list_file, str):
assert isinstance(root, str), 'list_file is str, root must be str'
list_file = [list_file]
root = [root]
self.class_list = class_list
if self.class_list is None:
logging.warning(
'It is recommended to specify the ``class_list`` parameter!')
self._auto_collect_labels = True
self.label_dict = {}
else:
assert isinstance(list_file, list), \
'list_file should be str or list(str)'
root = [root] if isinstance(root, str) else root
if not isinstance(root, list):
raise TypeError('root must be str or list(str), but get %s' %
type(root))

if len(root) < len(list_file):
logging.warning(
'len(root) < len(list_file), fill root with root last!')
root = root + [root[-1]] * (len(list_file) - len(root))

# TODO: support return list, donot save split file
# TODO: support loading list_file that have already been split
if split_huge_listfile_byrank:
with dist_zero_exec():
list_file = split_listfile_byrank(
list_file=list_file,
label_balance=split_label_balance,
save_path=cache_path)

self.fns = []
self.labels = []
label_dict = dict()
for l, r in zip(list_file, root):
fns, labels, label_dict = self.parse_list_file(l, label_dict)
self.fns += fns
self.labels += labels
self.label_dict = dict(
zip(self.class_list, range(len(self.class_list))))
self._auto_collect_labels = False
self.fns, self.labels, self.label_dict = self.parse_list_file(
list_file, self.label_dict, self._auto_collect_labels)

@staticmethod
def parse_list_file(list_file, label_dict):
with open(list_file, 'r', encoding='utf-8') as f:
data = f.readlines()
def parse_list_file(list_file, label_dict, auto_collect_labels=True):
with io.open(list_file, 'r') as f:
rows = f.read().splitlines()

fns = []
labels = []
for i in range(len(data)):
data_i = json.loads(data[i])
labels_id = []

for row_str in rows:
data_i = json.loads(row_str.strip())
img_path = data_i['data']['source']
label = []
label_id = []

priority = 2
for k in data_i.keys():
Expand All @@ -206,26 +163,27 @@ def parse_list_file(list_file, label_dict):

for k, v in data_i.items():
if 'label' in k:
label = []
label_id = []
result_list = v['results']
for j in range(len(result_list)):
anno_list = result_list[j]['data']
if 'labels' in anno_list:
if anno_list['labels']['单选'] not in label_dict:
label_dict[anno_list['labels']['单选']] = len(
label_dict)
label.append(label_dict[anno_list['labels']['单选']])
else:
if anno_list not in label_dict:
label_dict[anno_list] = len(label_dict)
label.append(label_dict[anno_list])
label = result_list[j]['data']
if 'labels' in label:
label = label['labels']['单选']
if label not in label_dict:
if auto_collect_labels:
label_dict[label] = len(label_dict)
else:
raise ValueError(
f'Not find label "{label}" in label dict: {label_dict}'
)
label_id.append(label_dict[label])
if 'verify' in k:
break
elif 'check' in k and priority == 1:
break

fns.append(img_path)
labels.append(
label[0]) if len(label) == 1 else labels.append(label)
labels_id.append(label_id[0]) if len(
label_id) == 1 else labels_id.append(label_id)

return fns, labels, label_dict
return fns, labels_id, label_dict
8 changes: 6 additions & 2 deletions easycv/predictors/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,17 @@ def __init__(self,
model_config: config string for model to init, in json format
"""
self.predictor = Predictor(model_path)
if 'class_list' not in self.predictor.cfg and label_map_path is None:
if 'class_list' not in self.predictor.cfg and \
'CLASSES' not in self.predictor.cfg and \
label_map_path is None:
raise ValueError(
"label_map_path need to be set, when ckpt doesn't contain class_list"
"'label_map_path' need to be set, when ckpt doesn't contain key 'class_list' and 'CLASSES'!"
)

if label_map_path is None:
class_list = self.predictor.cfg.get('class_list', [])
if len(class_list) < 1:
class_list = self.predictor.cfg.get('CLASSES', [])
self.label_map = [i.strip() for i in class_list]
else:
class_list = open(label_map_path).readlines()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from tests.ut_config import CLS_DATA_ITAG_OSS

from easycv.datasets.builder import build_datasource
from easycv.framework.errors import ValueError


class ClsSourceImageListTest(unittest.TestCase):
    """Unit tests for the ``ClsSourceItag`` data source (iTAG manifests)."""

    def setUp(self):
        # Announce the currently running test, matching the project's style.
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def _check_first_samples(self, source, sample_count=5):
        # Validate the first ``sample_count`` records: each yields an RGB
        # image and a label id inside the 3-class range.
        for index in range(sample_count):
            record = source[index]
            image = record['img']
            gt_label = record['gt_labels']
            self.assertEqual(image.mode, 'RGB')
            self.assertIn(gt_label, list(range(3)))
            image.close()

    def test_default(self):
        # Without ``class_list`` the source collects labels automatically,
        # in the order they appear in the manifest.
        from easycv.file import io
        io.access_oss()

        source = build_datasource(
            dict(type='ClsSourceItag', list_file=CLS_DATA_ITAG_OSS))

        self._check_first_samples(source)

        self.assertEqual(len(source), 11)
        self.assertDictEqual(source.label_dict, {
            'ng': 0,
            'ok': 1,
            '中文': 2
        })

    def test_with_class_list(self):
        # An explicit ``class_list`` fixes the label-to-id mapping order.
        from easycv.file import io
        io.access_oss()

        source = build_datasource(
            dict(
                type='ClsSourceItag',
                class_list=['中文', 'ng', 'ok'],
                list_file=CLS_DATA_ITAG_OSS))

        self._check_first_samples(source)

        self.assertEqual(len(source), 11)
        self.assertDictEqual(source.label_dict, {
            '中文': 0,
            'ng': 1,
            'ok': 2
        })

    def test_with_fault_class_list(self):
        # A ``class_list`` that misses a label present in the manifest must
        # raise the project's ValueError with a descriptive message.
        from easycv.file import io
        io.access_oss()

        with self.assertRaises(ValueError) as cm:
            source = build_datasource(
                dict(
                    type='ClsSourceItag',
                    class_list=['error', 'ng', 'ok'],
                    list_file=CLS_DATA_ITAG_OSS))
            self._check_first_samples(source)

        self.assertEqual(
            cm.exception.message,
            "Not find label \"中文\" in label dict: {'error': 0, 'ng': 1, 'ok': 2}"
        )


# Allow running this test module directly with `python <file>.py`.
if __name__ == '__main__':
    unittest.main()
4 changes: 4 additions & 0 deletions tests/ut_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
TMP_DIR_OSS = os.path.join(BASE_OSS_PATH, 'tmp')
TMP_DIR_LOCAL = os.path.join(BASE_LOCAL_PATH, 'tmp')

CLS_DATA_ITAG_OSS = os.path.join(
BASE_OSS_PATH,
'local_backup/easycv_nfs/data/classification/cls_itagtest/cls_itagtest.manifest'
)
CLS_DATA_NPY_LOCAL = os.path.join(BASE_LOCAL_PATH, 'data/classification/npy/')
SMALL_IMAGENET_RAW_LOCAL = os.path.join(
BASE_LOCAL_PATH, 'data/classification/small_imagenet_raw')
Expand Down
Loading

0 comments on commit edfd482

Please sign in to comment.