-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add mobilenet itag config
- Loading branch information
Showing
7 changed files
with
233 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
_base_ = '../imagenet/common/classification_base.py' | ||
|
||
# oss_io_config = dict(ak_id='', # your oss ak id | ||
# ak_secret='', # your oss ak secret | ||
# hosts='', # your oss hosts | ||
# buckets=[]) # your oss bucket name | ||
|
||
# Ensure the CLASSES definition is in one line, for adapt to its replacement by user_config_params. | ||
# yapf:disable | ||
CLASSES = ['label1', 'label2', 'label3'] # replace with your true lables of itag manifest file | ||
num_classes = 3 | ||
# model settings | ||
model = dict( | ||
type='Classification', | ||
backbone=dict(type='MobileNetV2'), | ||
head=dict( | ||
type='ClsHead', | ||
with_avg_pool=True, | ||
in_channels=1280, | ||
loss_config=dict( | ||
type='CrossEntropyLossWithLabelSmooth', | ||
label_smooth=0, | ||
), | ||
num_classes=num_classes)) | ||
|
||
train_itag_file = '/your/itag/train/file.manifest' # or oss://your_bucket/data/train.manifest | ||
test_itag_file = '/your/itag/test/file.manifest' # oss://your_bucket/data/test.manifest | ||
|
||
image_size2 = 224 | ||
image_size1 = int((256 / 224) * image_size2) | ||
data_source_type = 'ClsSourceItag' | ||
dataset_type = 'ClsDataset' | ||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||
train_pipeline = [ | ||
dict(type='RandomResizedCrop', size=image_size2), | ||
dict(type='RandomHorizontalFlip'), | ||
dict(type='ToTensor'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Collect', keys=['img', 'gt_labels']) | ||
] | ||
test_pipeline = [ | ||
dict(type='Resize', size=image_size1), | ||
dict(type='CenterCrop', size=image_size2), | ||
dict(type='ToTensor'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Collect', keys=['img', 'gt_labels']) | ||
] | ||
|
||
data = dict( | ||
imgs_per_gpu=32, # total 256 | ||
workers_per_gpu=4, | ||
train=dict( | ||
type=dataset_type, | ||
data_source=dict( | ||
type=data_source_type, | ||
list_file=train_itag_file, | ||
class_list=CLASSES, | ||
), | ||
pipeline=train_pipeline), | ||
val=dict( | ||
type=dataset_type, | ||
data_source=dict( | ||
type=data_source_type, | ||
list_file=test_itag_file, | ||
class_list=CLASSES), | ||
pipeline=test_pipeline)) | ||
|
||
eval_config = dict(initial=False, interval=1, gpu_collect=True) | ||
eval_pipelines = [ | ||
dict( | ||
mode='test', | ||
data=data['val'], | ||
dist_eval=True, | ||
evaluators=[dict(type='ClsEvaluator', topk=(1, ))], | ||
) | ||
] | ||
|
||
# optimizer | ||
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) | ||
|
||
# learning policy | ||
lr_config = dict(policy='step', step=[30, 60, 90]) | ||
checkpoint_config = dict(interval=5) | ||
|
||
# runtime settings | ||
total_epochs = 100 | ||
checkpoint_sync_export = True | ||
export = dict(export_neck=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
tests/datasets/classification/data_sources/test_cls_itag_datasource.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import unittest | ||
|
||
from tests.ut_config import CLS_DATA_ITAG_OSS | ||
|
||
from easycv.datasets.builder import build_datasource | ||
from easycv.framework.errors import ValueError | ||
|
||
|
||
class ClsSourceImageListTest(unittest.TestCase): | ||
|
||
def setUp(self): | ||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | ||
|
||
def test_default(self): | ||
from easycv.file import io | ||
io.access_oss() | ||
|
||
cfg = dict(type='ClsSourceItag', list_file=CLS_DATA_ITAG_OSS) | ||
data_source = build_datasource(cfg) | ||
|
||
index_list = list(range(5)) | ||
for idx in index_list: | ||
results = data_source[idx] | ||
img = results['img'] | ||
label = results['gt_labels'] | ||
self.assertEqual(img.mode, 'RGB') | ||
self.assertIn(label, list(range(3))) | ||
img.close() | ||
|
||
self.assertEqual(len(data_source), 11) | ||
self.assertDictEqual(data_source.label_dict, { | ||
'ng': 0, | ||
'ok': 1, | ||
'中文': 2 | ||
}) | ||
|
||
def test_with_class_list(self): | ||
from easycv.file import io | ||
io.access_oss() | ||
|
||
cfg = dict( | ||
type='ClsSourceItag', | ||
class_list=['中文', 'ng', 'ok'], | ||
list_file=CLS_DATA_ITAG_OSS) | ||
data_source = build_datasource(cfg) | ||
|
||
index_list = list(range(5)) | ||
for idx in index_list: | ||
results = data_source[idx] | ||
img = results['img'] | ||
label = results['gt_labels'] | ||
self.assertEqual(img.mode, 'RGB') | ||
self.assertIn(label, list(range(3))) | ||
img.close() | ||
|
||
self.assertEqual(len(data_source), 11) | ||
self.assertDictEqual(data_source.label_dict, { | ||
'中文': 0, | ||
'ng': 1, | ||
'ok': 2 | ||
}) | ||
|
||
def test_with_fault_class_list(self): | ||
from easycv.file import io | ||
io.access_oss() | ||
|
||
with self.assertRaises(ValueError) as cm: | ||
cfg = dict( | ||
type='ClsSourceItag', | ||
class_list=['error', 'ng', 'ok'], | ||
list_file=CLS_DATA_ITAG_OSS) | ||
|
||
data_source = build_datasource(cfg) | ||
index_list = list(range(5)) | ||
for idx in index_list: | ||
results = data_source[idx] | ||
img = results['img'] | ||
label = results['gt_labels'] | ||
self.assertEqual(img.mode, 'RGB') | ||
self.assertIn(label, list(range(3))) | ||
img.close() | ||
|
||
exception = cm.exception | ||
self.assertEqual( | ||
exception.message, | ||
"Not find label \"中文\" in label dict: {'error': 0, 'ng': 1, 'ok': 2}" | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.