From 23f2b0e39950e1377d29a16ecaa05f582ebc11f1 Mon Sep 17 00:00:00 2001 From: tuofeilun <38110862+tuofeilunhifi@users.noreply.github.com> Date: Thu, 1 Dec 2022 17:47:10 +0800 Subject: [PATCH] Adapt designer (#235) 1. Use original config as startup script. (For details, see refactor config parsing method #225) 2. Refactor the splicing rules of the check_base_cfg_path function in the EasyCV/easycv/utils/config_tools.py 3. Support three ways to pass class_list parameter. 4. Fix the bug that clsevalutor may make mistakes when evaluating top5. 5. Fix the bug that the distributed export cannot export the model. 6. Fix the bug that the load pretrained model key does not match. 7. support cls data source itag. --- .../imagenet/common/classification_base.py | 20 ++ .../common/dataset/imagenet_classification.py | 157 +++++++++ .../imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py | 61 +--- .../resnet/resnet50_b32x8_100e_jpg.py | 174 +--------- .../resnext/resnext50-32x4d_b32x8_100e_jpg.py | 61 +--- ...tiny_patch4_window7_224_b64x16_300e_jpg.py | 55 +-- .../vit_base_patch16_224_b64x64_300e_jpg.py | 54 +-- .../common/dataset/imagenet_metriclearning.py | 167 +++++++++ .../common/metriclearning_base.py | 6 + ...net_timm_modelparallel_softmaxbased_jpg.py | 39 +++ .../imagenet_timm_softmaxbased_jpg.py | 72 ++++ easycv/apis/export.py | 6 + easycv/core/evaluation/classification_eval.py | 5 + .../classification/data_sources/__init__.py | 6 +- .../classification/data_sources/image_list.py | 121 ++++++- .../models/classification/classification.py | 5 +- easycv/models/modelzoo.py | 19 ++ easycv/models/utils/multi_pooling.py | 4 - easycv/toolkit/hpo/det/config_dlc.ini | 2 +- easycv/toolkit/hpo/det/config_local.ini | 2 +- .../toolkit/hpo/det/fcos_r50_torch_1x_coco.py | 192 ----------- easycv/utils/checkpoint.py | 11 +- easycv/utils/config_tools.py | 317 +++++++++++++++--- requirements/runtime.txt | 2 +- setup.py | 4 +- tests/configs/test_adapt_pai_params.py | 65 ++++ tests/configs/test_check_base_cfg_path.py | 83 +++++ tests/configs/test_template.py | 1 - tests/hooks/test_export_hook.py | 2 +- tests/tools/test_classification_train.py | 35 +- tests/tools/test_mae_train.py | 1 - tests/ut_config.py | 2 + tools/train.py | 32 +- 33 files changed, 1113 insertions(+), 670 deletions(-) create mode 100644 configs/classification/imagenet/common/classification_base.py create mode 100644 configs/classification/imagenet/common/dataset/imagenet_classification.py create mode 100644 configs/metric_learning/common/dataset/imagenet_metriclearning.py create mode 100644 configs/metric_learning/common/metriclearning_base.py create mode 100644 configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py create mode 100644 configs/metric_learning/imagenet_timm_softmaxbased_jpg.py delete mode 100644 easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py create mode 100644 tests/configs/test_adapt_pai_params.py create mode 100644 tests/configs/test_check_base_cfg_path.py delete mode 100644 tests/configs/test_template.py diff --git a/configs/classification/imagenet/common/classification_base.py b/configs/classification/imagenet/common/classification_base.py new file mode 100644 index 00000000..4bbd3183 --- /dev/null +++ b/configs/classification/imagenet/common/classification_base.py @@ -0,0 +1,20 @@ +_base_ = 'configs/base.py' + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) + +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + +predict = dict( + type='ClassificationPredictor', + pipelines=[ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img']) + ]) diff --git a/configs/classification/imagenet/common/dataset/imagenet_classification.py b/configs/classification/imagenet/common/dataset/imagenet_classification.py new file mode 100644 index 00000000..2d385e07 --- /dev/null +++ b/configs/classification/imagenet/common/dataset/imagenet_classification.py @@ -0,0 +1,157 @@ +_base_ = '../classification_base.py' + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' +data_train_root = 'data/imagenet_raw/train/' +data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' +data_test_root = 'data/imagenet_raw/validation/' +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) + +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[ + dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) + ], + ) +] diff --git a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py index 104db80e..118ffa2a 100644 --- a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py @@ -1,8 +1,4 @@ -_base_ = 'configs/base.py' -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -19,61 +15,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/validation/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=100, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0001) diff --git a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py index 0e999255..f67afd08 100644 --- a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py @@ -1,109 +1,4 @@ -_base_ = '../../../base.py' -class_list = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', - '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', - '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', - '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', - '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', - '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', - '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', - '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', - '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', - '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', - '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', - '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', - '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', - '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', - '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', - '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', - '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', - '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', - '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', - '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', - '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', - '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', - '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', - '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', - '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', - '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', - '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', - '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', - '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', - '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', - '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', - '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', - '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', - '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', - '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', - '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', - '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', - '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', - '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', - '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', - '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', - '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', - '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', - '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', - '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', - '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', - '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', - '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', - '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', - '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', - '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', - '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', - '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', - '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', - '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', - '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', - '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', - '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', - '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', - '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', - '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', - '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', - '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', - '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', - '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', - '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', - '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', - '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', - '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', - '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', - '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', - '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', - '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', - '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', - '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', - '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', - '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', - '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', - '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', - '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', - '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', - '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', - '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', - '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', - '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', - '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', - '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', - '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', - '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', - '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', - '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', - '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', - '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', - '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', - '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', - '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', - '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', - '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' -] - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -123,63 +18,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/validation/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=False, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[ - dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) - ], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) @@ -189,13 +27,3 @@ # runtime settings total_epochs = 100 - -predict = dict( - type='ClassificationPredictor', - pipelines=[ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img']) - ]) diff --git a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py index 22f0f7c1..adb9ee7f 100644 --- a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py @@ -1,8 +1,4 @@ -_base_ = 'configs/base.py' -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -24,61 +20,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) diff --git a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py index b7381c5d..63ffdbf2 100644 --- a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py +++ b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py @@ -1,15 +1,9 @@ -_base_ = 'configs/base.py' - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( type='Classification', train_preprocess=['mixUp'], - pretrained=False, mixup_cfg=dict( mixup_alpha=0.8, cutmix_alpha=1.0, @@ -29,17 +23,11 @@ }, with_fc=False)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' +image_size2 = 224 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_pipeline = [ - dict(type='RandomResizedCrop', size=224), + dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), dict( type='MMRandAugment', @@ -62,44 +50,11 @@ dict(type='Normalize', **img_norm_cfg), dict(type='Collect', keys=['img', 'gt_labels']) ] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] data = dict( imgs_per_gpu=64, # total 256 workers_per_gpu=8, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] + train=dict(pipeline=train_pipeline)) # optimizer paramwise_options = { diff --git a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py index 1b1ca32b..94e0d7ca 100644 --- a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py +++ b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py @@ -1,15 +1,9 @@ -_base_ = 'configs/base.py' - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( type='Classification', train_preprocess=['mixUp'], - pretrained=False, mixup_cfg=dict( mixup_alpha=0.2, prob=1.0, @@ -28,17 +22,12 @@ }, with_fc=False)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_pipeline = [ - dict(type='RandomResizedCrop', size=224), + dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), dict(type='MMAutoAugment'), dict(type='ToTensor'), @@ -46,8 +35,8 @@ dict(type='Collect', keys=['img', 'gt_labels']) ] test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg), dict(type='Collect', keys=['img', 'gt_labels']) @@ -56,33 +45,8 @@ data = dict( imgs_per_gpu=64, # total 256 workers_per_gpu=8, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline)) # optimizer optimizer = dict( diff --git a/configs/metric_learning/common/dataset/imagenet_metriclearning.py b/configs/metric_learning/common/dataset/imagenet_metriclearning.py new file mode 100644 index 00000000..1f6d5049 --- /dev/null +++ b/configs/metric_learning/common/dataset/imagenet_metriclearning.py @@ -0,0 +1,167 @@ +_base_ = '../metriclearning_base.py' + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +data_train_list = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/meta/train_labeled.txt' +data_train_root = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/train/' +data_test_list = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/meta/val_labeled.txt' +data_test_root = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/validation/' +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) + +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + drop_last=True, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], + ), + dict( + mode='extract', + dist_eval=True, + data=data['val'], + evaluators=[ + dict( + type='RetrivalTopKEvaluator', + topk=(1, 2, 4, 8), + metric_names=('R@K=1', 'R@K=8')) + ], + ) +] diff --git a/configs/metric_learning/common/metriclearning_base.py b/configs/metric_learning/common/metriclearning_base.py new file mode 100644 index 00000000..0cc5fb74 --- /dev/null +++ b/configs/metric_learning/common/metriclearning_base.py @@ -0,0 +1,6 @@ +_base_ = 'configs/base.py' + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) diff --git a/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py b/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py new file mode 100644 index 00000000..c6c3f4d2 --- /dev/null +++ b/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py @@ -0,0 +1,39 @@ +_base_ = './imagenet_timm_softmaxbased_jpg.py' + +backbone_channels = 2048 +feature_channels = 1536 +num_classes = 300 + +metric_loss_name = 'ModelParallelSoftmaxLoss' +metric_loss_scale = 30 +metric_loss_margin = 0.4 + +# model settings +model = dict( + _delete_=True, + type='Classification', + backbone=dict(type='PytorchImageModelWrapper', model_name='resnet50'), + neck=dict( + type='RetrivalNeck', + in_channels=backbone_channels, + out_channels=feature_channels, + with_avg_pool=True, + cdg_config=['G', 'S']), + head=[ + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type=metric_loss_name, + loss_weight=1.0, + norm=False, + ddp=True, + scale=30, + margin=0.4, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[0]) + ]) diff --git a/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py b/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py new file mode 100644 index 00000000..62335a4a --- /dev/null +++ b/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py @@ -0,0 +1,72 @@ +_base_ = 'common/dataset/imagenet_metriclearning.py' + +backbone_channels = 2048 +feature_channels = 1536 +num_classes = 300 + +metric_loss_name = 'AMSoftmaxLoss' +metric_loss_scale = 30 +metric_loss_margin = 0.4 + +# model settings +model = dict( + type='Classification', + backbone=dict(type='PytorchImageModelWrapper', model_name='resnet50'), + neck=dict( + type='RetrivalNeck', + in_channels=backbone_channels, + out_channels=feature_channels, + with_avg_pool=True, + cdg_config=['G', 'S']), + head=[ + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type='CrossEntropyLossWithLabelSmooth', + loss_weight=1.0, + norm=True, + ddp=False, + label_smooth=0.1, + temperature=0.05, + with_cls=True, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[1]), + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type=metric_loss_name, + loss_weight=1.0, + norm=False, + ddp=False, + scale=metric_loss_scale, + margin=metric_loss_margin, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[0]) + ]) + +# optimizer +optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-6, + warmup='linear', + warmup_iters=100, + warmup_ratio=0.0001) + +checkpoint_config = dict(interval=5) +# runtime settings +total_epochs = 100 + +find_unused_parameters = True diff --git a/easycv/apis/export.py b/easycv/apis/export.py index 6c380008..0f6c1085 100644 --- a/easycv/apis/export.py +++ b/easycv/apis/export.py @@ -34,6 +34,12 @@ def export(cfg, ckpt_path, filename, **kwargs): ckpt_path (str): path to checkpoint file filename (str): filename to save exported models """ + + logging.warning( + 'Export needs to set pretrained to false to avoid hanging during distributed training' + ) + cfg.model['pretrained'] = False + model = build_model(cfg.model) if ckpt_path != 'dummy': load_checkpoint(model, ckpt_path, map_location='cpu') diff --git a/easycv/core/evaluation/classification_eval.py b/easycv/core/evaluation/classification_eval.py index c60c6ede..5b469d43 100644 --- a/easycv/core/evaluation/classification_eval.py +++ b/easycv/core/evaluation/classification_eval.py @@ -66,6 +66,11 @@ def _evaluate_impl(self, predictions, gt_labels): num = scores.size(0) _, pred = scores.topk( max(self._topk), dim=1, largest=True, sorted=True) + + # Avoid topk values greater than the number of categories + self._topk = np.array(list(self._topk)) + self._topk = np.clip(self._topk, 1, scores.shape[-1]) + pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) # KxN for k in self._topk: diff --git a/easycv/datasets/classification/data_sources/__init__.py b/easycv/datasets/classification/data_sources/__init__.py index 1427be6f..b3b9f7d9 100644 --- a/easycv/datasets/classification/data_sources/__init__.py +++ b/easycv/datasets/classification/data_sources/__init__.py @@ -2,12 +2,12 @@ from .cifar import ClsSourceCifar10, ClsSourceCifar100 from .class_list import ClsSourceImageListByClass from .cub import ClsSourceCUB -from .image_list import ClsSourceImageList +from .image_list import ClsSourceImageList, ClsSourceItag from .imagenet import ClsSourceImageNet1k from .imagenet_tfrecord import ClsSourceImageNetTFRecord __all__ = [ 'ClsSourceCifar10', 'ClsSourceCifar100', 'ClsSourceImageListByClass', - 'ClsSourceImageList', 'ClsSourceImageNetTFRecord', 'ClsSourceCUB', - 'ClsSourceImageNet1k' + 'ClsSourceImageList', 'ClsSourceItag', 'ClsSourceImageNetTFRecord', + 'ClsSourceCUB', 'ClsSourceImageNet1k' ] diff --git a/easycv/datasets/classification/data_sources/image_list.py b/easycv/datasets/classification/data_sources/image_list.py index e37f9fa8..00d34a3a 100644 --- a/easycv/datasets/classification/data_sources/image_list.py +++ b/easycv/datasets/classification/data_sources/image_list.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import logging import os import time @@ -89,7 +90,10 @@ def parse_list_file(list_file, root, delimeter): for l in lines: splits = l.strip().split(delimeter) - fns.append(os.path.join(root, splits[0])) + if len(root) > 0: + fns.append(os.path.join(root, splits[0])) + else: + fns.append(splits[0]) # must be int,other with mmcv collect will crash label = [int(i) for i in splits[1:]] labels.append( @@ -124,3 +128,118 @@ def __getitem__(self, idx): result_dict = {'img': img, 'gt_labels': label} return result_dict + + +@DATASOURCES.register_module +class ClsSourceItag(ClsSourceImageList): + """ data source itag for classification + + Args: + list_file : str / list(str), str means a input image list file path, + this file contains records as `image_path label` in list_file + list(str) means multi image list, each one contains some records as `image_path label` + root: str / list(str), root path for image_path, each list_file will need a root, + if len(root) < len(list_file), we will use root[-1] to fill root list. + delimeter: str, delimeter of each line in the `list_file` + split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list. + If split, data list will be split to each rank. + split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance + cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path. + max_try: int, max try numbers of reading image + """ + + def __init__(self, + list_file, + root='', + delimeter=' ', + split_huge_listfile_byrank=False, + split_label_balance=False, + cache_path='data/', + max_try=20): + + ImageFile.LOAD_TRUNCATED_IMAGES = True + + self.max_try = max_try + + # DistributedMPSampler need this attr + self.has_labels = True + + if isinstance(list_file, str): + assert isinstance(root, str), 'list_file is str, root must be str' + list_file = [list_file] + root = [root] + else: + assert isinstance(list_file, list), \ + 'list_file should be str or list(str)' + root = [root] if isinstance(root, str) else root + if not isinstance(root, list): + raise TypeError('root must be str or list(str), but get %s' % + type(root)) + + if len(root) < len(list_file): + logging.warning( + 'len(root) < len(list_file), fill root with root last!') + root = root + [root[-1]] * (len(list_file) - len(root)) + + # TODO: support return list, donot save split file + # TODO: support loading list_file that have already been split + if split_huge_listfile_byrank: + with dist_zero_exec(): + list_file = split_listfile_byrank( + list_file=list_file, + label_balance=split_label_balance, + save_path=cache_path) + + self.fns = [] + self.labels = [] + label_dict = dict() + for l, r in zip(list_file, root): + fns, labels, label_dict = self.parse_list_file(l, label_dict) + self.fns += fns + self.labels += labels + + @staticmethod + def parse_list_file(list_file, label_dict): + with open(list_file, 'r', encoding='utf-8') as f: + data = f.readlines() + + fns = [] + labels = [] + for i in range(len(data)): + data_i = json.loads(data[i]) + img_path = data_i['data']['source'] + label = [] + + priority = 2 + for k in data_i.keys(): + if 'verify' in k: + priority = 0 + break + elif 'check' in k: + priority = 1 + + for k, v in data_i.items(): + if 'label' in k: + label = [] + result_list = v['results'] + for j in range(len(result_list)): + anno_list = result_list[j]['data'] + if 'labels' in anno_list: + if anno_list['labels']['单选'] not in label_dict: + label_dict[anno_list['labels']['单选']] = len( + label_dict) + label.append(label_dict[anno_list['labels']['单选']]) + else: + if anno_list not in label_dict: + label_dict[anno_list] = len(label_dict) + label.append(label_dict[anno_list]) + if 'verify' in k: + break + elif 'check' in k and priority == 1: + break + + fns.append(img_path) + labels.append( + label[0]) if len(label) == 1 else labels.append(label) + + return fns, labels, label_dict diff --git a/easycv/models/classification/classification.py b/easycv/models/classification/classification.py index ccd30d50..57c32c08 100644 --- a/easycv/models/classification/classification.py +++ b/easycv/models/classification/classification.py @@ -120,7 +120,10 @@ def init_weights(self): self.backbone, self.backbone.default_pretrained_model_path, strict=False, - logger=logger) + logger=logger, + revise_keys=[ + (r'^backbone\.', '') + ]) # revise_keys is used to avoid load mismatches else: raise ValueError( 'default_pretrained_model_path for {} not found'.format( diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py index 58f005c4..d367cc62 100644 --- a/easycv/models/modelzoo.py +++ b/easycv/models/modelzoo.py @@ -40,6 +40,25 @@ 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw64/epoch_100.pth', } +vit = { + 'vit-base': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth', +} + +swint = { + 'swint-tiny': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth', +} + +deit = { + 'deitiii-small': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_small_patch16_224/deitiii_small.pth', + 'deitiii-base': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_base_patch16_192/deitiii_base.pth', + 'deitiii-large': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_large_patch16_192/deitiii_large.pth', +} + mobilenetv2 = { 'MobileNetV2_1.0': 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth', diff --git a/easycv/models/utils/multi_pooling.py b/easycv/models/utils/multi_pooling.py index 6ca1be89..75304d9a 100644 --- a/easycv/models/utils/multi_pooling.py +++ b/easycv/models/utils/multi_pooling.py @@ -23,10 +23,6 @@ def gem(self, x, p=3, eps=1e-6): return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) - def __repr__(self): - return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format( - self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' - class MultiPooling(nn.Module): """Pooling layers for features from multiple depth. diff --git a/easycv/toolkit/hpo/det/config_dlc.ini b/easycv/toolkit/hpo/det/config_dlc.ini index dd2b2a9c..188d6b2b 100644 --- a/easycv/toolkit/hpo/det/config_dlc.ini +++ b/easycv/toolkit/hpo/det/config_dlc.ini @@ -9,7 +9,7 @@ cmd2="dlc submit pytorch --name=test_nni_${exp_id}_${trial_id} \ --data_sources='d-domlyt834bngpr68iu' \ --worker_image=registry-vpc.cn-shanghai.aliyuncs.com/mybigpai/nni:0.0.3 \ --command='cd /mnt/data/EasyCV && pip install mmcv-full && pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \ - && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params --data_root /mnt/data/coco/ --data.imgs_per_gpu ${batch_size} --optimizer.lr ${lr} ' \ + && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py configs/detection/fcos/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params data_root='/mnt/data/coco/' data.imgs_per_gpu=${batch_size} optimizer.lr=${lr} ' \ --workspace_id='255705' " [metric_config] diff --git a/easycv/toolkit/hpo/det/config_local.ini b/easycv/toolkit/hpo/det/config_local.ini index 189a87ea..1f524f80 100644 --- a/easycv/toolkit/hpo/det/config_local.ini +++ b/easycv/toolkit/hpo/det/config_local.ini @@ -1,5 +1,5 @@ [cmd_config] -cmd1='cd /mnt/data/EasyCV && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params --data_root /mnt/data/coco/ --data.imgs_per_gpu ${batch_size} --optimizer.lr ${lr} ' +cmd1='cd /mnt/data/EasyCV && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py configs/detection/fcos/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params data_root='/mnt/data/coco/' data.imgs_per_gpu=${batch_size} optimizer.lr=${lr} ' [metric_config] metric_filepath=easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id}/tf_logs diff --git a/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py b/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py deleted file mode 100644 index 402b11d9..00000000 --- a/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py +++ /dev/null @@ -1,192 +0,0 @@ -train_cfg = {} -test_cfg = {} -optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb -# yapf:disable -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook') - ]) -# yapf:enable -# runtime settings -dist_params = dict(backend='nccl') -cudnn_benchmark = False -log_level = 'INFO' -load_from = None -resume_from = None -workflow = [('train', 1)] - -CLASSES = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' -] - -# dataset settings -data_root = '/mnt/data/coco/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) - -train_pipeline = [ - dict(type='MMResize', img_scale=(1333, 800), keep_ratio=True), - dict(type='MMRandomFlip', flip_ratio=0.5), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict( - type='Collect', - keys=['img', 'gt_bboxes', 'gt_labels'], - meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape', - 'img_shape', 'pad_shape', 'scale_factor', 'flip', - 'flip_direction', 'img_norm_cfg')) -] -test_pipeline = [ - dict( - type='MMMultiScaleFlipAug', - img_scale=(1333, 800), - flip=False, - transforms=[ - dict(type='MMResize', keep_ratio=True), - dict(type='MMRandomFlip'), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), - dict( - type='Collect', - keys=['img'], - meta_keys=('filename', 'ori_filename', 'ori_shape', - 'ori_img_shape', 'img_shape', 'pad_shape', - 'scale_factor', 'flip', 'flip_direction', - 'img_norm_cfg')) - ]) -] - -train_dataset = dict( - type='DetDataset', - data_source=dict( - type='DetSourceCoco', - ann_file='${data_root}' + 'annotations/instances_train2017.json', - img_prefix='${data_root}' + 'train2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=False, - filter_empty_gt=True, - iscrowd=False), - pipeline=train_pipeline) - -val_dataset = dict( - type='DetDataset', - imgs_per_gpu=1, - data_source=dict( - type='DetSourceCoco', - ann_file='${data_root}' + 'annotations/instances_val2017.json', - img_prefix='${data_root}' + 'val2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=True, - filter_empty_gt=False, - iscrowd=True), - pipeline=test_pipeline) - -data = dict( - imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset) - -# evaluation -eval_config = dict(interval=1, gpu_collect=False) -eval_pipelines = [ - dict( - mode='test', - evaluators=[ - dict(type='CocoDetectionEvaluator', classes=CLASSES), - ], - ) -] - -# model settings -model = dict( - type='Detection', - pretrained=True, - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(1, 2, 3, 4), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=False), - norm_eval=True, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs='on_output', # use P5 - num_outs=5, - relu_before_extra_convs=True), - head=dict( - type='FCOSHead', - num_classes=80, - in_channels=256, - stacked_convs=4, - feat_channels=256, - strides=[8, 16, 32, 64, 128], - center_sampling=True, - center_sample_radius=1.5, - norm_on_bbox=True, - centerness_on_reg=True, - conv_cfg=None, - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='GIoULoss', loss_weight=1.0), - loss_centerness=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), - norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), - conv_bias=True, - test_cfg=dict( - nms_pre=1000, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.6), - max_per_img=100))) - -checkpoint_config = dict(interval=10) -# optimizer -optimizer = dict( - type='SGD', - lr=0.01, - momentum=0.9, - weight_decay=0.0001, - paramwise_options=dict(bias_lr_mult=2., bias_decay_mult=0.)) -optimizer_config = dict(grad_clip=None) -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=500, - warmup_ratio=1.0 / 3, - step=[8, 11]) - -total_epochs = 12 - -find_unused_parameters = False diff --git a/easycv/utils/checkpoint.py b/easycv/utils/checkpoint.py index 4c987c83..c674eba5 100644 --- a/easycv/utils/checkpoint.py +++ b/easycv/utils/checkpoint.py @@ -18,7 +18,8 @@ def load_checkpoint(model, filename, map_location='cpu', strict=False, - logger=None): + logger=None, + revise_keys=[(r'^module\.', '')]): """Load checkpoint from a file or URI. Args: @@ -30,6 +31,10 @@ def load_checkpoint(model, strict (bool): Whether to allow different params for the model and checkpoint. logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. Returns: dict or OrderedDict: The loaded checkpoint. @@ -62,12 +67,14 @@ def load_checkpoint(model, ) and torch.distributed.is_initialized(): torch.distributed.barrier() filename = cache_file + return mmcv_load_checkpoint( model, filename, map_location=map_location, strict=strict, - logger=logger) + logger=logger, + revise_keys=revise_keys) def save_checkpoint(model, filename, optimizer=None, meta=None): diff --git a/easycv/utils/config_tools.py b/easycv/utils/config_tools.py index 90386c89..85fa452b 100644 --- a/easycv/utils/config_tools.py +++ b/easycv/utils/config_tools.py @@ -7,6 +7,8 @@ from mmcv import Config, import_modules_from_strings +import easycv +from easycv.file import io from easycv.framework.errors import IOError, KeyError, ValueError from .user_config_params_utils import check_value_type @@ -28,50 +30,130 @@ def traverse_replace(d, key, value): traverse_replace(v, key, value) -BASE_KEY = '_base_' - - -# To find base cfg in 'easycv/configs/', base_cfg_name should be 'configs/xx/xx.py' -# TODO: reset the api, keep the same way as mmcv `Config.fromfile` -def check_base_cfg_path(base_cfg_name='configs/base.py', ori_filename=None): - - if base_cfg_name == '../../base.py': - # To becompatible with previous config - base_cfg_name = 'configs/base.py' +class WrapperConfig(Config): + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. + + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ - base_cfg_dir_1 = osp.abspath(osp.dirname( - osp.dirname(__file__))) # easycv_package_root_path - base_cfg_path_1 = osp.join(base_cfg_dir_1, base_cfg_name) - print('Read base config from', base_cfg_path_1) - if osp.exists(base_cfg_path_1): - return base_cfg_path_1 + @staticmethod + def _substitute_predefined_vars(filename, + temp_config_name, + first_order_params=None): + """ + Override Config._substitute_predefined_vars. + Supports first-order parameter reuse to avoid rebuilding custom config.py templates. + + Args: + filename (str): Original script file. + temp_config_name (str): Template script file. + first_order_params (dict): first-order parameters. + + Returns: + No return value. + + """ + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + left_match, right_match = '{([', '])}' + match_list = [] + line_list = [] + for line in f: + # Push and pop control regular item matching + match_length_before = len(match_list) + for single_str in line: + if single_str in left_match: + match_list.append(single_str) + if single_str in right_match: + match_list.pop() + match_length_after = len(match_list) + + key = line.split('=')[0].strip() + # Check whether it is a first-order parameter + if match_length_before == match_length_after == 0 and first_order_params and key in first_order_params: + value = first_order_params[key] + # repr() is used to convert the data into a string form (in the form of a Python expression) suitable for the interpreter to read + line = ' '.join([key, '=', repr(value)]) + '\n' + + line_list.append(line) + config_file = ''.join(line_list) + + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + + +def check_base_cfg_path(base_cfg_name='configs/base.py', + father_cfg_name=None, + easycv_root=None): + """ + Concatenate paths by parsing path rules. - base_cfg_dir_2 = osp.dirname(base_cfg_dir_1) # upper level dir - base_cfg_path_2 = osp.join(base_cfg_dir_2, base_cfg_name) - print('Read base config from', base_cfg_path_2) - if osp.exists(base_cfg_path_2): - return base_cfg_path_2 + for example(pseudo-code): + 1. 'configs' in base_cfg_name or 'benchmarks' in base_cfg_name: + base_cfg_name = easycv_root + base_cfg_name - # relative to ori_filename - ori_cfg_dir = osp.dirname(ori_filename) - base_cfg_path_3 = osp.join(ori_cfg_dir, base_cfg_name) - base_cfg_path_3 = osp.abspath(osp.expanduser(base_cfg_path_3)) - if osp.exists(base_cfg_path_3): - return base_cfg_path_3 + 2. 'configs' not in base_cfg_name and 'benchmarks' not in base_cfg_name: + base_cfg_name = father_cfg_name + base_cfg_name - raise ValueError('%s not Found' % base_cfg_name) + """ + parse_base_cfg = base_cfg_name.split('/') + if parse_base_cfg[0] == 'configs' or parse_base_cfg[0] == 'benchmarks': + if easycv_root is not None: + base_cfg_name = osp.join(easycv_root, base_cfg_name) + else: + if father_cfg_name is not None: + _parse_base_path_list = base_cfg_name.split('/') + parse_base_path_list = copy.deepcopy(_parse_base_path_list) + parse_ori_path_list = father_cfg_name.split('/') + parse_ori_path_list.pop() + for filename in _parse_base_path_list: + if filename == '.': + parse_base_path_list.pop(0) + elif filename == '..': + parse_base_path_list.pop(0) + parse_ori_path_list.pop() + else: + break + base_cfg_name = '/'.join(parse_ori_path_list + + parse_base_path_list) + + return base_cfg_name # Read config without __base__ -def mmcv_file2dict_raw(ori_filename): - filename = osp.abspath(osp.expanduser(ori_filename)) - if not osp.isfile(filename): - if ori_filename.startswith('configs/'): - # read configs/config_templates/detection_oss.py - filename = check_base_cfg_path(ori_filename) - else: - raise ValueError('%s and %s not Found' % (ori_filename, filename)) - +def mmcv_file2dict_raw(filename, first_order_params=None): fileExtname = osp.splitext(filename)[1] if fileExtname not in ['.py', '.json', '.yaml', '.yml']: raise IOError('Only py/yml/yaml/json type are supported now!') @@ -82,7 +164,12 @@ def mmcv_file2dict_raw(ori_filename): if platform.system() == 'Windows': temp_config_file.close() temp_config_name = osp.basename(temp_config_file.name) - Config._substitute_predefined_vars(filename, temp_config_file.name) + if first_order_params is not None: + WrapperConfig._substitute_predefined_vars(filename, + temp_config_file.name, + first_order_params) + else: + Config._substitute_predefined_vars(filename, temp_config_file.name) if filename.endswith('.py'): temp_module_name = osp.splitext(temp_config_name)[0] sys.path.insert(0, temp_config_dir) @@ -110,11 +197,13 @@ def mmcv_file2dict_raw(ori_filename): # Reac config with __base__ -def mmcv_file2dict_base(ori_filename): - cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename) +def mmcv_file2dict_base(ori_filename, + first_order_params=None, + easycv_root=None): + cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename, first_order_params) + BASE_KEY = '_base_' if BASE_KEY in cfg_dict: - # cfg_dir = osp.dirname(filename) base_filename = cfg_dict.pop(BASE_KEY) base_filename = base_filename if isinstance(base_filename, list) else [base_filename] @@ -122,8 +211,10 @@ def mmcv_file2dict_base(ori_filename): cfg_dict_list = list() cfg_text_list = list() for f in base_filename: - base_cfg_path = check_base_cfg_path(f, ori_filename) - _cfg_dict, _cfg_text = mmcv_file2dict_base(base_cfg_path) + base_cfg_path = check_base_cfg_path( + f, ori_filename, easycv_root=easycv_root) + _cfg_dict, _cfg_text = mmcv_file2dict_base(base_cfg_path, + first_order_params) cfg_dict_list.append(_cfg_dict) cfg_text_list.append(_cfg_text) @@ -143,10 +234,80 @@ def mmcv_file2dict_base(ori_filename): return cfg_dict, cfg_text +def grouping_params(user_config_params): + first_order_params, multi_order_params = {}, {} + for full_key, v in user_config_params.items(): + key_list = full_key.split('.') + if len(key_list) == 1: + first_order_params[full_key] = v + else: + multi_order_params[full_key] = v + + return first_order_params, multi_order_params + + +def adapt_pai_params(cfg_dict, class_list_params=None): + """ + The user passes in the class_list_params. + + Args: + cfg_dict (dict): All parameters of cfg. + class_list_params (list): class_list_params[1] is num_classes. + class_list_params[0] supports three ways to build parameters. + str(.txt) parameter construction method: 0, 1, 2 or 0, \n, 1, \n, 2\n or 0, \n, 1, 2 or person, dog, cat. + list parameter construction method: '[0, 1, 2]' or '[person, dog, cat]' + '' parameter construction method: The default setting is str(0) - str(num_classes - 1) + + Returns: + cfg_dict (dict): Add the cfg of export and oss. + + """ + if class_list_params is not None: + class_list, num_classes = class_list_params[0], class_list_params[1] + if '.txt' in class_list: + cfg_dict['class_list'] = [] + with open(class_list, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + lines = f.readlines() + for line in lines: + line = line.strip().strip(',').replace(' ', '').split(',') + cfg_dict['class_list'].extend(line) + elif len(class_list) > 0: + cfg_dict['class_list'] = list(map(str, class_list)) + else: + cfg_dict['class_list'] = list(map(str, range(0, num_classes))) + + # export config + cfg_dict['export'] = dict(export_neck=True) + cfg_dict['checkpoint_sync_export'] = True + # oss config + cfg_dict['oss_sync_config'] = dict( + other_file_list=['**/events.out.tfevents*', '**/*log*']) + cfg_dict['oss_io_config'] = dict( + ak_id='your oss ak id', + ak_secret='your oss ak secret', + hosts='oss-cn-zhangjiakou.aliyuncs.com', + buckets=['your_bucket_2']) + return cfg_dict + + +def init_path(ori_filename): + easycv_root = osp.dirname(easycv.__file__) # easycv package root path + parse_ori_filename = ori_filename.split('/') + if parse_ori_filename[0] == 'configs' or parse_ori_filename[ + 0] == 'benchmarks': + if osp.exists(osp.join(easycv_root, ori_filename)): + ori_filename = osp.join(easycv_root, ori_filename) + + return ori_filename, easycv_root + + # gen mmcv.Config def mmcv_config_fromfile(ori_filename): + ori_filename, easycv_root = init_path(ori_filename) - cfg_dict, cfg_text = mmcv_file2dict_base(ori_filename) + cfg_dict, cfg_text = mmcv_file2dict_base( + ori_filename, easycv_root=easycv_root) if cfg_dict.get('custom_imports', None): import_modules_from_strings(**cfg_dict['custom_imports']) @@ -154,6 +315,46 @@ def mmcv_config_fromfile(ori_filename): return Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename) +def pai_config_fromfile(ori_filename, + user_config_params=None, + model_type=None): + ori_filename, easycv_root = init_path(ori_filename) + + if user_config_params is not None: + # set class_list + class_list_params = None + if 'class_list' in user_config_params: + class_list = user_config_params.pop('class_list') + for key, value in user_config_params.items(): + if 'num_classes' in key: + class_list_params = [class_list, value] + break + + # grouping params + first_order_params, multi_order_params = grouping_params( + user_config_params) + else: + class_list_params, first_order_params, multi_order_params = None, None, None + + # replace first-order parameters + cfg_dict, cfg_text = mmcv_file2dict_base( + ori_filename, first_order_params, easycv_root=easycv_root) + + # Add export and oss ​​related configuration to adapt to pai platform + if model_type: + cfg_dict = adapt_pai_params(cfg_dict, class_list_params) + + if cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) + + cfg = Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename) + + # replace multi-order parameters + if multi_order_params: + cfg.merge_from_dict(multi_order_params) + return cfg + + # get the true value for ori_key in cfg_dict def get_config_class_value(cfg_dict, ori_key, dict_mem_helper): if ori_key in dict_mem_helper: @@ -320,21 +521,27 @@ def validate_export_config(cfg): CONFIG_TEMPLATE_ZOO = { - # detection - 'YOLOX': 'configs/config_templates/yolox.py', - 'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py', - # cls - 'CLASSIFICATION': 'configs/config_templates/classification.py', - 'CLASSIFICATION_OSS': 'configs/config_templates/classification_oss.py', - 'CLASSIFICATION_TFRECORD_OSS': - 'configs/config_templates/classification_tfrecord_oss.py', + 'CLASSIFICATION_RESNET': + 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py', + 'CLASSIFICATION_RESNEXT': + 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py', + 'CLASSIFICATION_HRNET': + 'configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py', + 'CLASSIFICATION_VIT': + 'configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py', + 'CLASSIFICATION_SWINT': + 'configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py', # metric learning - 'METRICLEARNING_TFRECORD_OSS': - 'configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py', + 'METRICLEARNING': + 'configs/metric_learning/imagenet_timm_softmaxbased_jpg.py', 'MODELPARALLEL_METRICLEARNING': - 'configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py', + 'configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py', + + # detection + 'YOLOX': 'configs/config_templates/yolox.py', + 'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py', # ssl 'MOCO_R50_TFRECORD': 'configs/config_templates/moco_r50_tfrecord.py', diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 0bd25eee..9d8d4fd8 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -22,7 +22,7 @@ scikit-image sklearn tensorboard thop -timm>=0.4.9 +timm==0.5.4 wget xtcocotools yacs diff --git a/setup.py b/setup.py index 90f589d7..583dd776 100644 --- a/setup.py +++ b/setup.py @@ -154,6 +154,7 @@ def pack_resource(): proj_dir = root_dir + 'easycv/' shutil.copytree('./easycv', proj_dir) shutil.copytree('./configs', proj_dir + 'configs') + shutil.copytree('./benchmarks', proj_dir + 'benchmarks') shutil.copytree('./tools', proj_dir + 'tools') shutil.copytree('./resource', proj_dir + 'resource') shutil.copytree('./requirements', 'package/requirements') @@ -177,7 +178,8 @@ def pack_resource(): author_email='easycv@list.alibaba-inc.com', keywords='self-supvervised, classification, vision', url='https://github.com/alibaba/EasyCV.git', - packages=find_packages(exclude=('configs', 'tools', 'demo')), + packages=find_packages( + exclude=('configs', 'benchmarks', 'tools', 'demo')), include_package_data=True, classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tests/configs/test_adapt_pai_params.py b/tests/configs/test_adapt_pai_params.py new file mode 100644 index 00000000..64058d23 --- /dev/null +++ b/tests/configs/test_adapt_pai_params.py @@ -0,0 +1,65 @@ +import os.path as osp +import unittest + +from tests.ut_config import CLASS_LIST_TEST + +from easycv.utils.config_tools import adapt_pai_params + + +class AdaptPaiParamsTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_adapt_pai_params_0(self): + cfg_dict = {} + class_list_params = ['', 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], + ['0', '1', '2', '3', '4', '5', '6', '7']) + + def test_adapt_pai_params_1(self): + cfg_dict = {} + class_list_params = [['person', 'cat', 'dog'], 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], ['person', 'cat', 'dog']) + + def test_adapt_pai_params_2(self): + cfg_dict = {} + class_list_params = [[0, 1, 2], 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], ['0', '1', '2']) + + def test_adapt_pai_params_3(self): + cfg_dict = {} + class_list_params = [ + osp.join(CLASS_LIST_TEST, 'class_list_int_test.txt'), 8 + ] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], + ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']) + + def test_adapt_pai_params_4(self): + cfg_dict = {} + class_list_params = [ + osp.join(CLASS_LIST_TEST, 'class_list_str_test.txt'), 8 + ] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', + 'horse', 'ship', 'truck' + ]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/configs/test_check_base_cfg_path.py b/tests/configs/test_check_base_cfg_path.py new file mode 100644 index 00000000..81190c0c --- /dev/null +++ b/tests/configs/test_check_base_cfg_path.py @@ -0,0 +1,83 @@ +import unittest + +from easycv.utils.config_tools import check_base_cfg_path + + +class CheckPathTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_check_0(self): + base_cfg_name = 'configs/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, '/root/easycv/configs/base.py') + + def test_check_1(self): + base_cfg_name = 'benchmarks/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, '/root/easycv/benchmarks/base.py') + + def test_check_2(self): + base_cfg_name = '../base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, + 'configs/classification/imagenet/base.py') + + def test_check_3(self): + base_cfg_name = 'common/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual( + base_cfg_name, + 'configs/classification/imagenet/resnet/common/base.py') + + def test_check_4(self): + base_cfg_name = 'common/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'data/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, + 'data/classification/imagenet/resnet/common/base.py') + + def test_check_5(self): + base_cfg_name = '../base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'data/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, 'data/classification/imagenet/base.py') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/configs/test_template.py b/tests/configs/test_template.py deleted file mode 100644 index 2153d342..00000000 --- a/tests/configs/test_template.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: add template config unittest diff --git a/tests/hooks/test_export_hook.py b/tests/hooks/test_export_hook.py index 93dd4781..958344d1 100644 --- a/tests/hooks/test_export_hook.py +++ b/tests/hooks/test_export_hook.py @@ -64,7 +64,7 @@ def _build_model(): @MODELS.register_module() class TestExportModel(nn.Module): - def __init__(self): + def __init__(self, pretrained=False): super().__init__() self.linear = nn.Linear(2, 1) diff --git a/tests/tools/test_classification_train.py b/tests/tools/test_classification_train.py index 2d94827e..e04739e0 100644 --- a/tests/tools/test_classification_train.py +++ b/tests/tools/test_classification_train.py @@ -6,10 +6,10 @@ import tempfile import unittest -from mmcv import Config from tests.ut_config import SMALL_IMAGENET_RAW_LOCAL from easycv.file import io +from easycv.utils.config_tools import mmcv_config_fromfile, pai_config_fromfile from easycv.utils.test_util import run_in_subprocess sys.path.append(os.path.dirname(os.path.realpath(__file__))) @@ -51,6 +51,21 @@ SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt', 'model.train_preprocess': ['randomErasing', 'mixUp'] } +}, { + 'config_file': + 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py', + 'cfg_options': { + **_COMMON_OPTIONS, 'data_train_root': + SMALL_IMAGENET_DATA_ROOT + 'train/', + 'data_train_list': + SMALL_IMAGENET_DATA_ROOT + 'meta/train_labeled_200.txt', + 'data_test_root': SMALL_IMAGENET_DATA_ROOT + 'validation/', + 'data_test_list': + SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt', + 'image_resize2': [224, 224], + 'save_epochs': 1, + 'eval_epochs': 1 + } }] @@ -62,17 +77,22 @@ def setUp(self): def tearDown(self): super().tearDown() - def _base_train(self, train_cfgs): + def _base_train(self, train_cfgs, adapt_pai=False): cfg_file = train_cfgs.pop('config_file') cfg_options = train_cfgs.pop('cfg_options', None) work_dir = train_cfgs.pop('work_dir', None) if not work_dir: work_dir = tempfile.TemporaryDirectory().name - cfg = Config.fromfile(cfg_file) - if cfg_options is not None: - cfg.merge_from_dict(cfg_options) + if adapt_pai: + cfg = pai_config_fromfile(cfg_file, user_config_params=cfg_options) cfg.eval_pipelines[0].data = cfg.data.val + else: + cfg = mmcv_config_fromfile(cfg_file) + if cfg_options is not None: + cfg.merge_from_dict(cfg_options) + cfg.eval_pipelines[0].data = cfg.data.val + tmp_cfg_file = tempfile.NamedTemporaryFile(suffix='.py').name cfg.dump(tmp_cfg_file) @@ -100,6 +120,11 @@ def test_classification_mixup(self): self._base_train(train_cfgs) + def test_classification_pai(self): + train_cfgs = copy.deepcopy(TRAIN_CONFIGS[2]) + + self._base_train(train_cfgs, adapt_pai=True) + if __name__ == '__main__': unittest.main() diff --git a/tests/tools/test_mae_train.py b/tests/tools/test_mae_train.py index ffdafb60..c3e1b956 100644 --- a/tests/tools/test_mae_train.py +++ b/tests/tools/test_mae_train.py @@ -76,7 +76,6 @@ def _base_train(self, train_cfgs): if not work_dir: work_dir = tempfile.TemporaryDirectory().name - # cfg = Config.fromfile(cfg_file) cfg = mmcv_config_fromfile(cfg_file) if cfg_options is not None: cfg.merge_from_dict(cfg_options) diff --git a/tests/ut_config.py b/tests/ut_config.py index d5d76d90..340d6bd5 100644 --- a/tests/ut_config.py +++ b/tests/ut_config.py @@ -42,6 +42,8 @@ CIFAR100_LOCAL = os.path.join(BASE_LOCAL_PATH, 'data/classification/cifar100') SAMLL_IMAGENET1K_RAW_LOCAL = os.path.join(BASE_LOCAL_PATH, 'datasets/imagenet-1k/imagenet_raw') +CLASS_LIST_TEST = os.path.join(BASE_LOCAL_PATH, + 'data/classification/class_list_test') SMALL_IMAGENET_TFRECORD_LOCAL = os.path.join( BASE_LOCAL_PATH, 'data/classification/small_imagenet_tfrecord/') diff --git a/tools/train.py b/tools/train.py index fab34e87..2b44c7c4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -25,6 +25,7 @@ import torch import torch.distributed as dist from mmcv.runner import init_dist +from mmcv import DictAction from easycv import __version__ from easycv.apis import init_random_seed, set_random_seed, train_model @@ -35,9 +36,9 @@ from easycv.utils.collect_env import collect_env from easycv.utils.logger import get_root_logger from easycv.utils import mmlab_utils -from easycv.utils.config_tools import traverse_replace -from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO, - mmcv_config_fromfile, rebuild_config) +from easycv.utils.config_tools import (traverse_replace, CONFIG_TEMPLATE_ZOO, + mmcv_config_fromfile, + pai_config_fromfile) from easycv.utils.dist_utils import get_device, is_master from easycv.utils.setup_env import setup_multi_processes @@ -93,9 +94,14 @@ def parse_args(): ) parser.add_argument( '--user_config_params', - nargs=argparse.REMAINDER, - default=None, - help='modify config options using the command-line') + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed. Single quote double quote equivalent.') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: @@ -127,12 +133,13 @@ def main(): pass args.config = tpath - cfg = mmcv_config_fromfile(args.config) - if args.user_config_params is not None: - # assert args.model_type is not None, 'model_type must be setted' - # rebuild config by user config params - cfg = rebuild_config(cfg, args.user_config_params) + # build cfg + if args.user_config_params is None: + cfg = mmcv_config_fromfile(args.config) + else: + cfg = pai_config_fromfile(args.config, args.user_config_params, + args.model_type) # set multi-process settings setup_multi_processes(cfg) @@ -275,7 +282,8 @@ def main(): pin_memory=getattr(cfg.data, 'pin_memory', False), replace=getattr(cfg.data, 'sampling_replace', False), seed=cfg.seed, - drop_last=getattr(cfg.data, 'drop_last', False), + # The default should be set to True, because sometimes the last batch is not sampled enough, causing an error in batchnorm + drop_last=getattr(cfg.data, 'drop_last', True), reuse_worker_cache=cfg.data.get('reuse_worker_cache', False), persistent_workers=cfg.data.get('persistent_workers', False), collate_hooks=cfg.data.get('train_collate_hooks', []),