diff --git a/configs/classification/imagenet/common/classification_base.py b/configs/classification/imagenet/common/classification_base.py new file mode 100644 index 00000000..4bbd3183 --- /dev/null +++ b/configs/classification/imagenet/common/classification_base.py @@ -0,0 +1,20 @@ +_base_ = 'configs/base.py' + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) + +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + +predict = dict( + type='ClassificationPredictor', + pipelines=[ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img']) + ]) diff --git a/configs/classification/imagenet/common/dataset/imagenet_classification.py b/configs/classification/imagenet/common/dataset/imagenet_classification.py new file mode 100644 index 00000000..2d385e07 --- /dev/null +++ b/configs/classification/imagenet/common/dataset/imagenet_classification.py @@ -0,0 +1,157 @@ +_base_ = '../classification_base.py' + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' +data_train_root = 'data/imagenet_raw/train/' +data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' +data_test_root = 'data/imagenet_raw/validation/' +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) + +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[ + dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) + ], + ) +] diff --git a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py index 104db80e..118ffa2a 100644 --- a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py @@ -1,8 +1,4 @@ -_base_ = 'configs/base.py' -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -19,61 +15,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/validation/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=100, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0001) diff --git a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py index 0e999255..f67afd08 100644 --- a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py @@ -1,109 +1,4 @@ -_base_ = '../../../base.py' -class_list = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', - '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', - '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', - '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', - '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', - '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', - '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', - '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', - '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', - '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', - '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', - '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', - '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', - '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', - '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', - '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', - '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', - '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', - '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', - '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', - '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', - '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', - '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', - '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', - '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', - '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', - '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', - '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', - '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', - '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', - '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', - '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', - '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', - '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', - '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', - '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', - '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', - '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', - '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', - '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', - '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', - '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', - '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', - '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', - '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', - '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', - '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', - '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', - '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', - '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', - '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', - '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', - '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', - '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', - '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', - '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', - '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', - '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', - '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', - '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', - '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', - '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', - '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', - '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', - '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', - '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', - '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', - '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', - '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', - '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', - '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', - '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', - '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', - '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', - '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', - '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', - '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', - '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', - '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', - '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', - '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', - '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', - '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', - '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', - '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', - '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', - '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', - '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', - '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', - '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', - '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', - '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', - '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', - '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', - '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', - '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', - '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', - '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' -] - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -123,63 +18,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/validation/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=False, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[ - dict(type='ClsEvaluator', topk=(1, 5), class_list=class_list) - ], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) @@ -189,13 +27,3 @@ # runtime settings total_epochs = 100 - -predict = dict( - type='ClassificationPredictor', - pipelines=[ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img']) - ]) diff --git a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py index 22f0f7c1..adb9ee7f 100644 --- a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py @@ -1,8 +1,4 @@ -_base_ = 'configs/base.py' -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( @@ -24,61 +20,6 @@ ), num_classes=1000)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' -img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -train_pipeline = [ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] - -data = dict( - imgs_per_gpu=32, # total 256 - workers_per_gpu=4, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] - # optimizer optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) diff --git a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py index b7381c5d..63ffdbf2 100644 --- a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py +++ b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py @@ -1,15 +1,9 @@ -_base_ = 'configs/base.py' - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( type='Classification', train_preprocess=['mixUp'], - pretrained=False, mixup_cfg=dict( mixup_alpha=0.8, cutmix_alpha=1.0, @@ -29,17 +23,11 @@ }, with_fc=False)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' +image_size2 = 224 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_pipeline = [ - dict(type='RandomResizedCrop', size=224), + dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), dict( type='MMRandAugment', @@ -62,44 +50,11 @@ dict(type='Normalize', **img_norm_cfg), dict(type='Collect', keys=['img', 'gt_labels']) ] -test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), - dict(type='ToTensor'), - dict(type='Normalize', **img_norm_cfg), - dict(type='Collect', keys=['img', 'gt_labels']) -] data = dict( imgs_per_gpu=64, # total 256 workers_per_gpu=8, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] + train=dict(pipeline=train_pipeline)) # optimizer paramwise_options = { diff --git a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py index 1b1ca32b..94e0d7ca 100644 --- a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py +++ b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py @@ -1,15 +1,9 @@ -_base_ = 'configs/base.py' - -log_config = dict( - interval=10, - hooks=[dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook')]) +_base_ = '../common/dataset/imagenet_classification.py' # model settings model = dict( type='Classification', train_preprocess=['mixUp'], - pretrained=False, mixup_cfg=dict( mixup_alpha=0.2, prob=1.0, @@ -28,17 +22,12 @@ }, with_fc=False)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' -data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/val/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' - -dataset_type = 'ClsDataset' +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + train_pipeline = [ - dict(type='RandomResizedCrop', size=224), + dict(type='RandomResizedCrop', size=image_size2), dict(type='RandomHorizontalFlip'), dict(type='MMAutoAugment'), dict(type='ToTensor'), @@ -46,8 +35,8 @@ dict(type='Collect', keys=['img', 'gt_labels']) ] test_pipeline = [ - dict(type='Resize', size=256), - dict(type='CenterCrop', size=224), + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg), dict(type='Collect', keys=['img', 'gt_labels']) @@ -56,33 +45,8 @@ data = dict( imgs_per_gpu=64, # total 256 workers_per_gpu=8, - train=dict( - type=dataset_type, - data_source=dict( - list_file=data_train_list, - root=data_train_root, - type='ClsSourceImageList'), - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_source=dict( - list_file=data_test_list, - root=data_test_root, - type='ClsSourceImageList'), - pipeline=test_pipeline)) - -eval_config = dict(initial=True, interval=1, gpu_collect=True) -eval_pipelines = [ - dict( - mode='test', - data=data['val'], - dist_eval=True, - evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], - ) -] - -# additional hooks -custom_hooks = [] + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline)) # optimizer optimizer = dict( diff --git a/configs/metric_learning/common/dataset/imagenet_metriclearning.py b/configs/metric_learning/common/dataset/imagenet_metriclearning.py new file mode 100644 index 00000000..1f6d5049 --- /dev/null +++ b/configs/metric_learning/common/dataset/imagenet_metriclearning.py @@ -0,0 +1,167 @@ +_base_ = '../metriclearning_base.py' + +class_list = [ + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', + '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', + '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', + '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', + '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', + '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', + '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', + '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', + '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', + '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', + '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', + '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', + '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', + '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', + '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', + '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', + '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', + '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', + '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', + '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', + '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', + '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', + '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', + '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', + '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', + '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', + '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', + '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', + '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', + '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', + '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', + '329', '330', '331', '332', '333', '334', '335', '336', '337', '338', + '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', + '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', + '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', + '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', + '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', + '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', + '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', + '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', + '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', + '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', + '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', + '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', + '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', + '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', + '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', + '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', + '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', + '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', + '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', + '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', + '549', '550', '551', '552', '553', '554', '555', '556', '557', '558', + '559', '560', '561', '562', '563', '564', '565', '566', '567', '568', + '569', '570', '571', '572', '573', '574', '575', '576', '577', '578', + '579', '580', '581', '582', '583', '584', '585', '586', '587', '588', + '589', '590', '591', '592', '593', '594', '595', '596', '597', '598', + '599', '600', '601', '602', '603', '604', '605', '606', '607', '608', + '609', '610', '611', '612', '613', '614', '615', '616', '617', '618', + '619', '620', '621', '622', '623', '624', '625', '626', '627', '628', + '629', '630', '631', '632', '633', '634', '635', '636', '637', '638', + '639', '640', '641', '642', '643', '644', '645', '646', '647', '648', + '649', '650', '651', '652', '653', '654', '655', '656', '657', '658', + '659', '660', '661', '662', '663', '664', '665', '666', '667', '668', + '669', '670', '671', '672', '673', '674', '675', '676', '677', '678', + '679', '680', '681', '682', '683', '684', '685', '686', '687', '688', + '689', '690', '691', '692', '693', '694', '695', '696', '697', '698', + '699', '700', '701', '702', '703', '704', '705', '706', '707', '708', + '709', '710', '711', '712', '713', '714', '715', '716', '717', '718', + '719', '720', '721', '722', '723', '724', '725', '726', '727', '728', + '729', '730', '731', '732', '733', '734', '735', '736', '737', '738', + '739', '740', '741', '742', '743', '744', '745', '746', '747', '748', + '749', '750', '751', '752', '753', '754', '755', '756', '757', '758', + '759', '760', '761', '762', '763', '764', '765', '766', '767', '768', + '769', '770', '771', '772', '773', '774', '775', '776', '777', '778', + '779', '780', '781', '782', '783', '784', '785', '786', '787', '788', + '789', '790', '791', '792', '793', '794', '795', '796', '797', '798', + '799', '800', '801', '802', '803', '804', '805', '806', '807', '808', + '809', '810', '811', '812', '813', '814', '815', '816', '817', '818', + '819', '820', '821', '822', '823', '824', '825', '826', '827', '828', + '829', '830', '831', '832', '833', '834', '835', '836', '837', '838', + '839', '840', '841', '842', '843', '844', '845', '846', '847', '848', + '849', '850', '851', '852', '853', '854', '855', '856', '857', '858', + '859', '860', '861', '862', '863', '864', '865', '866', '867', '868', + '869', '870', '871', '872', '873', '874', '875', '876', '877', '878', + '879', '880', '881', '882', '883', '884', '885', '886', '887', '888', + '889', '890', '891', '892', '893', '894', '895', '896', '897', '898', + '899', '900', '901', '902', '903', '904', '905', '906', '907', '908', + '909', '910', '911', '912', '913', '914', '915', '916', '917', '918', + '919', '920', '921', '922', '923', '924', '925', '926', '927', '928', + '929', '930', '931', '932', '933', '934', '935', '936', '937', '938', + '939', '940', '941', '942', '943', '944', '945', '946', '947', '948', + '949', '950', '951', '952', '953', '954', '955', '956', '957', '958', + '959', '960', '961', '962', '963', '964', '965', '966', '967', '968', + '969', '970', '971', '972', '973', '974', '975', '976', '977', '978', + '979', '980', '981', '982', '983', '984', '985', '986', '987', '988', + '989', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999' +] + +data_source_type = 'ClsSourceImageList' +data_train_list = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/meta/train_labeled.txt' +data_train_root = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/train/' +data_test_list = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/meta/val_labeled.txt' +data_test_root = '/apsarapangu/disk1/yunji.cjy/data/imagenet_raw/validation/' +dataset_type = 'ClsDataset' +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +image_size2 = 224 +image_size1 = int((256 / 224) * image_size2) + +train_pipeline = [ + dict(type='RandomResizedCrop', size=image_size2), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] +test_pipeline = [ + dict(type='Resize', size=image_size1), + dict(type='CenterCrop', size=image_size2), + dict(type='ToTensor'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Collect', keys=['img', 'gt_labels']) +] + +data = dict( + imgs_per_gpu=32, # total 256 + workers_per_gpu=4, + drop_last=True, + train=dict( + type=dataset_type, + data_source=dict( + list_file=data_train_list, + root=data_train_root, + type=data_source_type), + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_source=dict( + list_file=data_test_list, + root=data_test_root, + type=data_source_type), + pipeline=test_pipeline)) + +eval_config = dict(initial=False, interval=1, gpu_collect=True) +eval_pipelines = [ + dict( + mode='test', + data=data['val'], + dist_eval=True, + evaluators=[dict(type='ClsEvaluator', topk=(1, 5))], + ), + dict( + mode='extract', + dist_eval=True, + data=data['val'], + evaluators=[ + dict( + type='RetrivalTopKEvaluator', + topk=(1, 2, 4, 8), + metric_names=('R@K=1', 'R@K=8')) + ], + ) +] diff --git a/configs/metric_learning/common/metriclearning_base.py b/configs/metric_learning/common/metriclearning_base.py new file mode 100644 index 00000000..0cc5fb74 --- /dev/null +++ b/configs/metric_learning/common/metriclearning_base.py @@ -0,0 +1,6 @@ +_base_ = 'configs/base.py' + +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) diff --git a/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py b/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py new file mode 100644 index 00000000..c6c3f4d2 --- /dev/null +++ b/configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py @@ -0,0 +1,39 @@ +_base_ = './imagenet_timm_softmaxbased_jpg.py' + +backbone_channels = 2048 +feature_channels = 1536 +num_classes = 300 + +metric_loss_name = 'ModelParallelSoftmaxLoss' +metric_loss_scale = 30 +metric_loss_margin = 0.4 + +# model settings +model = dict( + _delete_=True, + type='Classification', + backbone=dict(type='PytorchImageModelWrapper', model_name='resnet50'), + neck=dict( + type='RetrivalNeck', + in_channels=backbone_channels, + out_channels=feature_channels, + with_avg_pool=True, + cdg_config=['G', 'S']), + head=[ + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type=metric_loss_name, + loss_weight=1.0, + norm=False, + ddp=True, + scale=30, + margin=0.4, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[0]) + ]) diff --git a/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py b/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py new file mode 100644 index 00000000..62335a4a --- /dev/null +++ b/configs/metric_learning/imagenet_timm_softmaxbased_jpg.py @@ -0,0 +1,72 @@ +_base_ = 'common/dataset/imagenet_metriclearning.py' + +backbone_channels = 2048 +feature_channels = 1536 +num_classes = 300 + +metric_loss_name = 'AMSoftmaxLoss' +metric_loss_scale = 30 +metric_loss_margin = 0.4 + +# model settings +model = dict( + type='Classification', + backbone=dict(type='PytorchImageModelWrapper', model_name='resnet50'), + neck=dict( + type='RetrivalNeck', + in_channels=backbone_channels, + out_channels=feature_channels, + with_avg_pool=True, + cdg_config=['G', 'S']), + head=[ + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type='CrossEntropyLossWithLabelSmooth', + loss_weight=1.0, + norm=True, + ddp=False, + label_smooth=0.1, + temperature=0.05, + with_cls=True, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[1]), + dict( + type='MpMetrixHead', + with_avg_pool=True, + in_channels=feature_channels, + loss_config=[ + dict( + type=metric_loss_name, + loss_weight=1.0, + norm=False, + ddp=False, + scale=metric_loss_scale, + margin=metric_loss_margin, + embedding_size=feature_channels, + num_classes=num_classes) + ], + input_feature_index=[0]) + ]) + +# optimizer +optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001) + +# learning policy +lr_config = dict( + policy='CosineAnnealing', + min_lr=1e-6, + warmup='linear', + warmup_iters=100, + warmup_ratio=0.0001) + +checkpoint_config = dict(interval=5) +# runtime settings +total_epochs = 100 + +find_unused_parameters = True diff --git a/easycv/apis/export.py b/easycv/apis/export.py index 6c380008..0f6c1085 100644 --- a/easycv/apis/export.py +++ b/easycv/apis/export.py @@ -34,6 +34,12 @@ def export(cfg, ckpt_path, filename, **kwargs): ckpt_path (str): path to checkpoint file filename (str): filename to save exported models """ + + logging.warning( + 'Export needs to set pretrained to false to avoid hanging during distributed training' + ) + cfg.model['pretrained'] = False + model = build_model(cfg.model) if ckpt_path != 'dummy': load_checkpoint(model, ckpt_path, map_location='cpu') diff --git a/easycv/core/evaluation/classification_eval.py b/easycv/core/evaluation/classification_eval.py index c60c6ede..5b469d43 100644 --- a/easycv/core/evaluation/classification_eval.py +++ b/easycv/core/evaluation/classification_eval.py @@ -66,6 +66,11 @@ def _evaluate_impl(self, predictions, gt_labels): num = scores.size(0) _, pred = scores.topk( max(self._topk), dim=1, largest=True, sorted=True) + + # Avoid topk values greater than the number of categories + self._topk = np.array(list(self._topk)) + self._topk = np.clip(self._topk, 1, scores.shape[-1]) + pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) # KxN for k in self._topk: diff --git a/easycv/datasets/classification/data_sources/__init__.py b/easycv/datasets/classification/data_sources/__init__.py index 1427be6f..b3b9f7d9 100644 --- a/easycv/datasets/classification/data_sources/__init__.py +++ b/easycv/datasets/classification/data_sources/__init__.py @@ -2,12 +2,12 @@ from .cifar import ClsSourceCifar10, ClsSourceCifar100 from .class_list import ClsSourceImageListByClass from .cub import ClsSourceCUB -from .image_list import ClsSourceImageList +from .image_list import ClsSourceImageList, ClsSourceItag from .imagenet import ClsSourceImageNet1k from .imagenet_tfrecord import ClsSourceImageNetTFRecord __all__ = [ 'ClsSourceCifar10', 'ClsSourceCifar100', 'ClsSourceImageListByClass', - 'ClsSourceImageList', 'ClsSourceImageNetTFRecord', 'ClsSourceCUB', - 'ClsSourceImageNet1k' + 'ClsSourceImageList', 'ClsSourceItag', 'ClsSourceImageNetTFRecord', + 'ClsSourceCUB', 'ClsSourceImageNet1k' ] diff --git a/easycv/datasets/classification/data_sources/image_list.py b/easycv/datasets/classification/data_sources/image_list.py index e37f9fa8..00d34a3a 100644 --- a/easycv/datasets/classification/data_sources/image_list.py +++ b/easycv/datasets/classification/data_sources/image_list.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import json import logging import os import time @@ -89,7 +90,10 @@ def parse_list_file(list_file, root, delimeter): for l in lines: splits = l.strip().split(delimeter) - fns.append(os.path.join(root, splits[0])) + if len(root) > 0: + fns.append(os.path.join(root, splits[0])) + else: + fns.append(splits[0]) # must be int,other with mmcv collect will crash label = [int(i) for i in splits[1:]] labels.append( @@ -124,3 +128,118 @@ def __getitem__(self, idx): result_dict = {'img': img, 'gt_labels': label} return result_dict + + +@DATASOURCES.register_module +class ClsSourceItag(ClsSourceImageList): + """ data source itag for classification + + Args: + list_file : str / list(str), str means a input image list file path, + this file contains records as `image_path label` in list_file + list(str) means multi image list, each one contains some records as `image_path label` + root: str / list(str), root path for image_path, each list_file will need a root, + if len(root) < len(list_file), we will use root[-1] to fill root list. + delimeter: str, delimeter of each line in the `list_file` + split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list. + If split, data list will be split to each rank. + split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance + cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path. + max_try: int, max try numbers of reading image + """ + + def __init__(self, + list_file, + root='', + delimeter=' ', + split_huge_listfile_byrank=False, + split_label_balance=False, + cache_path='data/', + max_try=20): + + ImageFile.LOAD_TRUNCATED_IMAGES = True + + self.max_try = max_try + + # DistributedMPSampler need this attr + self.has_labels = True + + if isinstance(list_file, str): + assert isinstance(root, str), 'list_file is str, root must be str' + list_file = [list_file] + root = [root] + else: + assert isinstance(list_file, list), \ + 'list_file should be str or list(str)' + root = [root] if isinstance(root, str) else root + if not isinstance(root, list): + raise TypeError('root must be str or list(str), but get %s' % + type(root)) + + if len(root) < len(list_file): + logging.warning( + 'len(root) < len(list_file), fill root with root last!') + root = root + [root[-1]] * (len(list_file) - len(root)) + + # TODO: support return list, donot save split file + # TODO: support loading list_file that have already been split + if split_huge_listfile_byrank: + with dist_zero_exec(): + list_file = split_listfile_byrank( + list_file=list_file, + label_balance=split_label_balance, + save_path=cache_path) + + self.fns = [] + self.labels = [] + label_dict = dict() + for l, r in zip(list_file, root): + fns, labels, label_dict = self.parse_list_file(l, label_dict) + self.fns += fns + self.labels += labels + + @staticmethod + def parse_list_file(list_file, label_dict): + with open(list_file, 'r', encoding='utf-8') as f: + data = f.readlines() + + fns = [] + labels = [] + for i in range(len(data)): + data_i = json.loads(data[i]) + img_path = data_i['data']['source'] + label = [] + + priority = 2 + for k in data_i.keys(): + if 'verify' in k: + priority = 0 + break + elif 'check' in k: + priority = 1 + + for k, v in data_i.items(): + if 'label' in k: + label = [] + result_list = v['results'] + for j in range(len(result_list)): + anno_list = result_list[j]['data'] + if 'labels' in anno_list: + if anno_list['labels']['单选'] not in label_dict: + label_dict[anno_list['labels']['单选']] = len( + label_dict) + label.append(label_dict[anno_list['labels']['单选']]) + else: + if anno_list not in label_dict: + label_dict[anno_list] = len(label_dict) + label.append(label_dict[anno_list]) + if 'verify' in k: + break + elif 'check' in k and priority == 1: + break + + fns.append(img_path) + labels.append( + label[0]) if len(label) == 1 else labels.append(label) + + return fns, labels, label_dict diff --git a/easycv/models/classification/classification.py b/easycv/models/classification/classification.py index ccd30d50..57c32c08 100644 --- a/easycv/models/classification/classification.py +++ b/easycv/models/classification/classification.py @@ -120,7 +120,10 @@ def init_weights(self): self.backbone, self.backbone.default_pretrained_model_path, strict=False, - logger=logger) + logger=logger, + revise_keys=[ + (r'^backbone\.', '') + ]) # revise_keys is used to avoid load mismatches else: raise ValueError( 'default_pretrained_model_path for {} not found'.format( diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py index 58f005c4..d367cc62 100644 --- a/easycv/models/modelzoo.py +++ b/easycv/models/modelzoo.py @@ -40,6 +40,25 @@ 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw64/epoch_100.pth', } +vit = { + 'vit-base': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth', +} + +swint = { + 'swint-tiny': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth', +} + +deit = { + 'deitiii-small': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_small_patch16_224/deitiii_small.pth', + 'deitiii-base': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_base_patch16_192/deitiii_base.pth', + 'deitiii-large': + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/deitiii/imagenet_deitiii_large_patch16_192/deitiii_large.pth', +} + mobilenetv2 = { 'MobileNetV2_1.0': 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/mobilenetv2/mobilenet_v2.pth', diff --git a/easycv/models/utils/multi_pooling.py b/easycv/models/utils/multi_pooling.py index 6ca1be89..75304d9a 100644 --- a/easycv/models/utils/multi_pooling.py +++ b/easycv/models/utils/multi_pooling.py @@ -23,10 +23,6 @@ def gem(self, x, p=3, eps=1e-6): return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) - def __repr__(self): - return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format( - self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' - class MultiPooling(nn.Module): """Pooling layers for features from multiple depth. diff --git a/easycv/toolkit/hpo/det/config_dlc.ini b/easycv/toolkit/hpo/det/config_dlc.ini index dd2b2a9c..188d6b2b 100644 --- a/easycv/toolkit/hpo/det/config_dlc.ini +++ b/easycv/toolkit/hpo/det/config_dlc.ini @@ -9,7 +9,7 @@ cmd2="dlc submit pytorch --name=test_nni_${exp_id}_${trial_id} \ --data_sources='d-domlyt834bngpr68iu' \ --worker_image=registry-vpc.cn-shanghai.aliyuncs.com/mybigpai/nni:0.0.3 \ --command='cd /mnt/data/EasyCV && pip install mmcv-full && pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple \ - && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params --data_root /mnt/data/coco/ --data.imgs_per_gpu ${batch_size} --optimizer.lr ${lr} ' \ + && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py configs/detection/fcos/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params data_root='/mnt/data/coco/' data.imgs_per_gpu=${batch_size} optimizer.lr=${lr} ' \ --workspace_id='255705' " [metric_config] diff --git a/easycv/toolkit/hpo/det/config_local.ini b/easycv/toolkit/hpo/det/config_local.ini index 189a87ea..1f524f80 100644 --- a/easycv/toolkit/hpo/det/config_local.ini +++ b/easycv/toolkit/hpo/det/config_local.ini @@ -1,5 +1,5 @@ [cmd_config] -cmd1='cd /mnt/data/EasyCV && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params --data_root /mnt/data/coco/ --data.imgs_per_gpu ${batch_size} --optimizer.lr ${lr} ' +cmd1='cd /mnt/data/EasyCV && CUDA_VISIBLE_DEVICES=0,1,2,3,4 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29400 tools/train.py configs/detection/fcos/fcos_r50_torch_1x_coco.py --work_dir easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id} --launcher pytorch --seed 42 --deterministic --user_config_params data_root='/mnt/data/coco/' data.imgs_per_gpu=${batch_size} optimizer.lr=${lr} ' [metric_config] metric_filepath=easycv/toolkit/hpo/det/model/model_${exp_id}_${trial_id}/tf_logs diff --git a/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py b/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py deleted file mode 100644 index 402b11d9..00000000 --- a/easycv/toolkit/hpo/det/fcos_r50_torch_1x_coco.py +++ /dev/null @@ -1,192 +0,0 @@ -train_cfg = {} -test_cfg = {} -optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb -# yapf:disable -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook') - ]) -# yapf:enable -# runtime settings -dist_params = dict(backend='nccl') -cudnn_benchmark = False -log_level = 'INFO' -load_from = None -resume_from = None -workflow = [('train', 1)] - -CLASSES = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' -] - -# dataset settings -data_root = '/mnt/data/coco/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) - -train_pipeline = [ - dict(type='MMResize', img_scale=(1333, 800), keep_ratio=True), - dict(type='MMRandomFlip', flip_ratio=0.5), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=32), - dict(type='DefaultFormatBundle'), - dict( - type='Collect', - keys=['img', 'gt_bboxes', 'gt_labels'], - meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape', - 'img_shape', 'pad_shape', 'scale_factor', 'flip', - 'flip_direction', 'img_norm_cfg')) -] -test_pipeline = [ - dict( - type='MMMultiScaleFlipAug', - img_scale=(1333, 800), - flip=False, - transforms=[ - dict(type='MMResize', keep_ratio=True), - dict(type='MMRandomFlip'), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), - dict( - type='Collect', - keys=['img'], - meta_keys=('filename', 'ori_filename', 'ori_shape', - 'ori_img_shape', 'img_shape', 'pad_shape', - 'scale_factor', 'flip', 'flip_direction', - 'img_norm_cfg')) - ]) -] - -train_dataset = dict( - type='DetDataset', - data_source=dict( - type='DetSourceCoco', - ann_file='${data_root}' + 'annotations/instances_train2017.json', - img_prefix='${data_root}' + 'train2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=False, - filter_empty_gt=True, - iscrowd=False), - pipeline=train_pipeline) - -val_dataset = dict( - type='DetDataset', - imgs_per_gpu=1, - data_source=dict( - type='DetSourceCoco', - ann_file='${data_root}' + 'annotations/instances_val2017.json', - img_prefix='${data_root}' + 'val2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=True, - filter_empty_gt=False, - iscrowd=True), - pipeline=test_pipeline) - -data = dict( - imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset) - -# evaluation -eval_config = dict(interval=1, gpu_collect=False) -eval_pipelines = [ - dict( - mode='test', - evaluators=[ - dict(type='CocoDetectionEvaluator', classes=CLASSES), - ], - ) -] - -# model settings -model = dict( - type='Detection', - pretrained=True, - backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(1, 2, 3, 4), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=False), - norm_eval=True, - style='pytorch'), - neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - start_level=1, - add_extra_convs='on_output', # use P5 - num_outs=5, - relu_before_extra_convs=True), - head=dict( - type='FCOSHead', - num_classes=80, - in_channels=256, - stacked_convs=4, - feat_channels=256, - strides=[8, 16, 32, 64, 128], - center_sampling=True, - center_sample_radius=1.5, - norm_on_bbox=True, - centerness_on_reg=True, - conv_cfg=None, - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='GIoULoss', loss_weight=1.0), - loss_centerness=dict( - type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), - norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), - conv_bias=True, - test_cfg=dict( - nms_pre=1000, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.6), - max_per_img=100))) - -checkpoint_config = dict(interval=10) -# optimizer -optimizer = dict( - type='SGD', - lr=0.01, - momentum=0.9, - weight_decay=0.0001, - paramwise_options=dict(bias_lr_mult=2., bias_decay_mult=0.)) -optimizer_config = dict(grad_clip=None) -# learning policy -lr_config = dict( - policy='step', - warmup='linear', - warmup_iters=500, - warmup_ratio=1.0 / 3, - step=[8, 11]) - -total_epochs = 12 - -find_unused_parameters = False diff --git a/easycv/utils/checkpoint.py b/easycv/utils/checkpoint.py index 4c987c83..c674eba5 100644 --- a/easycv/utils/checkpoint.py +++ b/easycv/utils/checkpoint.py @@ -18,7 +18,8 @@ def load_checkpoint(model, filename, map_location='cpu', strict=False, - logger=None): + logger=None, + revise_keys=[(r'^module\.', '')]): """Load checkpoint from a file or URI. Args: @@ -30,6 +31,10 @@ def load_checkpoint(model, strict (bool): Whether to allow different params for the model and checkpoint. logger (:mod:`logging.Logger` or None): The logger for error message. + revise_keys (list): A list of customized keywords to modify the + state_dict in checkpoint. Each item is a (pattern, replacement) + pair of the regular expression operations. Default: strip + the prefix 'module.' by [(r'^module\\.', '')]. Returns: dict or OrderedDict: The loaded checkpoint. @@ -62,12 +67,14 @@ def load_checkpoint(model, ) and torch.distributed.is_initialized(): torch.distributed.barrier() filename = cache_file + return mmcv_load_checkpoint( model, filename, map_location=map_location, strict=strict, - logger=logger) + logger=logger, + revise_keys=revise_keys) def save_checkpoint(model, filename, optimizer=None, meta=None): diff --git a/easycv/utils/config_tools.py b/easycv/utils/config_tools.py index 90386c89..85fa452b 100644 --- a/easycv/utils/config_tools.py +++ b/easycv/utils/config_tools.py @@ -7,6 +7,8 @@ from mmcv import Config, import_modules_from_strings +import easycv +from easycv.file import io from easycv.framework.errors import IOError, KeyError, ValueError from .user_config_params_utils import check_value_type @@ -28,50 +30,130 @@ def traverse_replace(d, key, value): traverse_replace(v, key, value) -BASE_KEY = '_base_' - - -# To find base cfg in 'easycv/configs/', base_cfg_name should be 'configs/xx/xx.py' -# TODO: reset the api, keep the same way as mmcv `Config.fromfile` -def check_base_cfg_path(base_cfg_name='configs/base.py', ori_filename=None): - - if base_cfg_name == '../../base.py': - # To becompatible with previous config - base_cfg_name = 'configs/base.py' +class WrapperConfig(Config): + """A facility for config and config files. + + It supports common file formats as configs: python/json/yaml. The interface + is the same as a dict object and also allows access config values as + attributes. + + Example: + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> cfg.a + 1 + >>> cfg.b + {'b1': [0, 1]} + >>> cfg.b.b1 + [0, 1] + >>> cfg = Config.fromfile('tests/data/config/a.py') + >>> cfg.filename + "/home/kchen/projects/mmcv/tests/data/config/a.py" + >>> cfg.item4 + 'test' + >>> cfg + "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " + "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" + """ - base_cfg_dir_1 = osp.abspath(osp.dirname( - osp.dirname(__file__))) # easycv_package_root_path - base_cfg_path_1 = osp.join(base_cfg_dir_1, base_cfg_name) - print('Read base config from', base_cfg_path_1) - if osp.exists(base_cfg_path_1): - return base_cfg_path_1 + @staticmethod + def _substitute_predefined_vars(filename, + temp_config_name, + first_order_params=None): + """ + Override Config._substitute_predefined_vars. + Supports first-order parameter reuse to avoid rebuilding custom config.py templates. + + Args: + filename (str): Original script file. + temp_config_name (str): Template script file. + first_order_params (dict): first-order parameters. + + Returns: + No return value. + + """ + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname) + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + left_match, right_match = '{([', '])}' + match_list = [] + line_list = [] + for line in f: + # Push and pop control regular item matching + match_length_before = len(match_list) + for single_str in line: + if single_str in left_match: + match_list.append(single_str) + if single_str in right_match: + match_list.pop() + match_length_after = len(match_list) + + key = line.split('=')[0].strip() + # Check whether it is a first-order parameter + if match_length_before == match_length_after == 0 and first_order_params and key in first_order_params: + value = first_order_params[key] + # repr() is used to convert the data into a string form (in the form of a Python expression) suitable for the interpreter to read + line = ' '.join([key, '=', repr(value)]) + '\n' + + line_list.append(line) + config_file = ''.join(line_list) + + for key, value in support_templates.items(): + regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' + value = value.replace('\\', '/') + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: + tmp_config_file.write(config_file) + + +def check_base_cfg_path(base_cfg_name='configs/base.py', + father_cfg_name=None, + easycv_root=None): + """ + Concatenate paths by parsing path rules. - base_cfg_dir_2 = osp.dirname(base_cfg_dir_1) # upper level dir - base_cfg_path_2 = osp.join(base_cfg_dir_2, base_cfg_name) - print('Read base config from', base_cfg_path_2) - if osp.exists(base_cfg_path_2): - return base_cfg_path_2 + for example(pseudo-code): + 1. 'configs' in base_cfg_name or 'benchmarks' in base_cfg_name: + base_cfg_name = easycv_root + base_cfg_name - # relative to ori_filename - ori_cfg_dir = osp.dirname(ori_filename) - base_cfg_path_3 = osp.join(ori_cfg_dir, base_cfg_name) - base_cfg_path_3 = osp.abspath(osp.expanduser(base_cfg_path_3)) - if osp.exists(base_cfg_path_3): - return base_cfg_path_3 + 2. 'configs' not in base_cfg_name and 'benchmarks' not in base_cfg_name: + base_cfg_name = father_cfg_name + base_cfg_name - raise ValueError('%s not Found' % base_cfg_name) + """ + parse_base_cfg = base_cfg_name.split('/') + if parse_base_cfg[0] == 'configs' or parse_base_cfg[0] == 'benchmarks': + if easycv_root is not None: + base_cfg_name = osp.join(easycv_root, base_cfg_name) + else: + if father_cfg_name is not None: + _parse_base_path_list = base_cfg_name.split('/') + parse_base_path_list = copy.deepcopy(_parse_base_path_list) + parse_ori_path_list = father_cfg_name.split('/') + parse_ori_path_list.pop() + for filename in _parse_base_path_list: + if filename == '.': + parse_base_path_list.pop(0) + elif filename == '..': + parse_base_path_list.pop(0) + parse_ori_path_list.pop() + else: + break + base_cfg_name = '/'.join(parse_ori_path_list + + parse_base_path_list) + + return base_cfg_name # Read config without __base__ -def mmcv_file2dict_raw(ori_filename): - filename = osp.abspath(osp.expanduser(ori_filename)) - if not osp.isfile(filename): - if ori_filename.startswith('configs/'): - # read configs/config_templates/detection_oss.py - filename = check_base_cfg_path(ori_filename) - else: - raise ValueError('%s and %s not Found' % (ori_filename, filename)) - +def mmcv_file2dict_raw(filename, first_order_params=None): fileExtname = osp.splitext(filename)[1] if fileExtname not in ['.py', '.json', '.yaml', '.yml']: raise IOError('Only py/yml/yaml/json type are supported now!') @@ -82,7 +164,12 @@ def mmcv_file2dict_raw(ori_filename): if platform.system() == 'Windows': temp_config_file.close() temp_config_name = osp.basename(temp_config_file.name) - Config._substitute_predefined_vars(filename, temp_config_file.name) + if first_order_params is not None: + WrapperConfig._substitute_predefined_vars(filename, + temp_config_file.name, + first_order_params) + else: + Config._substitute_predefined_vars(filename, temp_config_file.name) if filename.endswith('.py'): temp_module_name = osp.splitext(temp_config_name)[0] sys.path.insert(0, temp_config_dir) @@ -110,11 +197,13 @@ def mmcv_file2dict_raw(ori_filename): # Reac config with __base__ -def mmcv_file2dict_base(ori_filename): - cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename) +def mmcv_file2dict_base(ori_filename, + first_order_params=None, + easycv_root=None): + cfg_dict, cfg_text = mmcv_file2dict_raw(ori_filename, first_order_params) + BASE_KEY = '_base_' if BASE_KEY in cfg_dict: - # cfg_dir = osp.dirname(filename) base_filename = cfg_dict.pop(BASE_KEY) base_filename = base_filename if isinstance(base_filename, list) else [base_filename] @@ -122,8 +211,10 @@ def mmcv_file2dict_base(ori_filename): cfg_dict_list = list() cfg_text_list = list() for f in base_filename: - base_cfg_path = check_base_cfg_path(f, ori_filename) - _cfg_dict, _cfg_text = mmcv_file2dict_base(base_cfg_path) + base_cfg_path = check_base_cfg_path( + f, ori_filename, easycv_root=easycv_root) + _cfg_dict, _cfg_text = mmcv_file2dict_base(base_cfg_path, + first_order_params) cfg_dict_list.append(_cfg_dict) cfg_text_list.append(_cfg_text) @@ -143,10 +234,80 @@ def mmcv_file2dict_base(ori_filename): return cfg_dict, cfg_text +def grouping_params(user_config_params): + first_order_params, multi_order_params = {}, {} + for full_key, v in user_config_params.items(): + key_list = full_key.split('.') + if len(key_list) == 1: + first_order_params[full_key] = v + else: + multi_order_params[full_key] = v + + return first_order_params, multi_order_params + + +def adapt_pai_params(cfg_dict, class_list_params=None): + """ + The user passes in the class_list_params. + + Args: + cfg_dict (dict): All parameters of cfg. + class_list_params (list): class_list_params[1] is num_classes. + class_list_params[0] supports three ways to build parameters. + str(.txt) parameter construction method: 0, 1, 2 or 0, \n, 1, \n, 2\n or 0, \n, 1, 2 or person, dog, cat. + list parameter construction method: '[0, 1, 2]' or '[person, dog, cat]' + '' parameter construction method: The default setting is str(0) - str(num_classes - 1) + + Returns: + cfg_dict (dict): Add the cfg of export and oss. + + """ + if class_list_params is not None: + class_list, num_classes = class_list_params[0], class_list_params[1] + if '.txt' in class_list: + cfg_dict['class_list'] = [] + with open(class_list, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + lines = f.readlines() + for line in lines: + line = line.strip().strip(',').replace(' ', '').split(',') + cfg_dict['class_list'].extend(line) + elif len(class_list) > 0: + cfg_dict['class_list'] = list(map(str, class_list)) + else: + cfg_dict['class_list'] = list(map(str, range(0, num_classes))) + + # export config + cfg_dict['export'] = dict(export_neck=True) + cfg_dict['checkpoint_sync_export'] = True + # oss config + cfg_dict['oss_sync_config'] = dict( + other_file_list=['**/events.out.tfevents*', '**/*log*']) + cfg_dict['oss_io_config'] = dict( + ak_id='your oss ak id', + ak_secret='your oss ak secret', + hosts='oss-cn-zhangjiakou.aliyuncs.com', + buckets=['your_bucket_2']) + return cfg_dict + + +def init_path(ori_filename): + easycv_root = osp.dirname(easycv.__file__) # easycv package root path + parse_ori_filename = ori_filename.split('/') + if parse_ori_filename[0] == 'configs' or parse_ori_filename[ + 0] == 'benchmarks': + if osp.exists(osp.join(easycv_root, ori_filename)): + ori_filename = osp.join(easycv_root, ori_filename) + + return ori_filename, easycv_root + + # gen mmcv.Config def mmcv_config_fromfile(ori_filename): + ori_filename, easycv_root = init_path(ori_filename) - cfg_dict, cfg_text = mmcv_file2dict_base(ori_filename) + cfg_dict, cfg_text = mmcv_file2dict_base( + ori_filename, easycv_root=easycv_root) if cfg_dict.get('custom_imports', None): import_modules_from_strings(**cfg_dict['custom_imports']) @@ -154,6 +315,46 @@ def mmcv_config_fromfile(ori_filename): return Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename) +def pai_config_fromfile(ori_filename, + user_config_params=None, + model_type=None): + ori_filename, easycv_root = init_path(ori_filename) + + if user_config_params is not None: + # set class_list + class_list_params = None + if 'class_list' in user_config_params: + class_list = user_config_params.pop('class_list') + for key, value in user_config_params.items(): + if 'num_classes' in key: + class_list_params = [class_list, value] + break + + # grouping params + first_order_params, multi_order_params = grouping_params( + user_config_params) + else: + class_list_params, first_order_params, multi_order_params = None, None, None + + # replace first-order parameters + cfg_dict, cfg_text = mmcv_file2dict_base( + ori_filename, first_order_params, easycv_root=easycv_root) + + # Add export and oss ​​related configuration to adapt to pai platform + if model_type: + cfg_dict = adapt_pai_params(cfg_dict, class_list_params) + + if cfg_dict.get('custom_imports', None): + import_modules_from_strings(**cfg_dict['custom_imports']) + + cfg = Config(cfg_dict, cfg_text=cfg_text, filename=ori_filename) + + # replace multi-order parameters + if multi_order_params: + cfg.merge_from_dict(multi_order_params) + return cfg + + # get the true value for ori_key in cfg_dict def get_config_class_value(cfg_dict, ori_key, dict_mem_helper): if ori_key in dict_mem_helper: @@ -320,21 +521,27 @@ def validate_export_config(cfg): CONFIG_TEMPLATE_ZOO = { - # detection - 'YOLOX': 'configs/config_templates/yolox.py', - 'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py', - # cls - 'CLASSIFICATION': 'configs/config_templates/classification.py', - 'CLASSIFICATION_OSS': 'configs/config_templates/classification_oss.py', - 'CLASSIFICATION_TFRECORD_OSS': - 'configs/config_templates/classification_tfrecord_oss.py', + 'CLASSIFICATION_RESNET': + 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py', + 'CLASSIFICATION_RESNEXT': + 'configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py', + 'CLASSIFICATION_HRNET': + 'configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py', + 'CLASSIFICATION_VIT': + 'configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py', + 'CLASSIFICATION_SWINT': + 'configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py', # metric learning - 'METRICLEARNING_TFRECORD_OSS': - 'configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py', + 'METRICLEARNING': + 'configs/metric_learning/imagenet_timm_softmaxbased_jpg.py', 'MODELPARALLEL_METRICLEARNING': - 'configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py', + 'configs/metric_learning/imagenet_timm_modelparallel_softmaxbased_jpg.py', + + # detection + 'YOLOX': 'configs/config_templates/yolox.py', + 'YOLOX_ITAG': 'configs/config_templates/yolox_itag.py', # ssl 'MOCO_R50_TFRECORD': 'configs/config_templates/moco_r50_tfrecord.py', diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 0bd25eee..9d8d4fd8 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -22,7 +22,7 @@ scikit-image sklearn tensorboard thop -timm>=0.4.9 +timm==0.5.4 wget xtcocotools yacs diff --git a/setup.py b/setup.py index 90f589d7..583dd776 100644 --- a/setup.py +++ b/setup.py @@ -154,6 +154,7 @@ def pack_resource(): proj_dir = root_dir + 'easycv/' shutil.copytree('./easycv', proj_dir) shutil.copytree('./configs', proj_dir + 'configs') + shutil.copytree('./benchmarks', proj_dir + 'benchmarks') shutil.copytree('./tools', proj_dir + 'tools') shutil.copytree('./resource', proj_dir + 'resource') shutil.copytree('./requirements', 'package/requirements') @@ -177,7 +178,8 @@ def pack_resource(): author_email='easycv@list.alibaba-inc.com', keywords='self-supvervised, classification, vision', url='https://github.com/alibaba/EasyCV.git', - packages=find_packages(exclude=('configs', 'tools', 'demo')), + packages=find_packages( + exclude=('configs', 'benchmarks', 'tools', 'demo')), include_package_data=True, classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tests/configs/test_adapt_pai_params.py b/tests/configs/test_adapt_pai_params.py new file mode 100644 index 00000000..64058d23 --- /dev/null +++ b/tests/configs/test_adapt_pai_params.py @@ -0,0 +1,65 @@ +import os.path as osp +import unittest + +from tests.ut_config import CLASS_LIST_TEST + +from easycv.utils.config_tools import adapt_pai_params + + +class AdaptPaiParamsTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_adapt_pai_params_0(self): + cfg_dict = {} + class_list_params = ['', 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], + ['0', '1', '2', '3', '4', '5', '6', '7']) + + def test_adapt_pai_params_1(self): + cfg_dict = {} + class_list_params = [['person', 'cat', 'dog'], 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], ['person', 'cat', 'dog']) + + def test_adapt_pai_params_2(self): + cfg_dict = {} + class_list_params = [[0, 1, 2], 8] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], ['0', '1', '2']) + + def test_adapt_pai_params_3(self): + cfg_dict = {} + class_list_params = [ + osp.join(CLASS_LIST_TEST, 'class_list_int_test.txt'), 8 + ] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], + ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']) + + def test_adapt_pai_params_4(self): + cfg_dict = {} + class_list_params = [ + osp.join(CLASS_LIST_TEST, 'class_list_str_test.txt'), 8 + ] + cfg_dict = adapt_pai_params( + cfg_dict, class_list_params=class_list_params) + + self.assertEqual(cfg_dict['class_list'], [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', + 'horse', 'ship', 'truck' + ]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/configs/test_check_base_cfg_path.py b/tests/configs/test_check_base_cfg_path.py new file mode 100644 index 00000000..81190c0c --- /dev/null +++ b/tests/configs/test_check_base_cfg_path.py @@ -0,0 +1,83 @@ +import unittest + +from easycv.utils.config_tools import check_base_cfg_path + + +class CheckPathTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def test_check_0(self): + base_cfg_name = 'configs/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, '/root/easycv/configs/base.py') + + def test_check_1(self): + base_cfg_name = 'benchmarks/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, '/root/easycv/benchmarks/base.py') + + def test_check_2(self): + base_cfg_name = '../base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, + 'configs/classification/imagenet/base.py') + + def test_check_3(self): + base_cfg_name = 'common/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual( + base_cfg_name, + 'configs/classification/imagenet/resnet/common/base.py') + + def test_check_4(self): + base_cfg_name = 'common/base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'data/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, + 'data/classification/imagenet/resnet/common/base.py') + + def test_check_5(self): + base_cfg_name = '../base.py' + easycv_root = '/root/easycv' + father_cfg_name = 'data/classification/imagenet/resnet/imagenet_resnet50_jpg.py' + base_cfg_name = check_base_cfg_path( + base_cfg_name=base_cfg_name, + father_cfg_name=father_cfg_name, + easycv_root=easycv_root) + + self.assertEqual(base_cfg_name, 'data/classification/imagenet/base.py') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/configs/test_template.py b/tests/configs/test_template.py deleted file mode 100644 index 2153d342..00000000 --- a/tests/configs/test_template.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: add template config unittest diff --git a/tests/hooks/test_export_hook.py b/tests/hooks/test_export_hook.py index 93dd4781..958344d1 100644 --- a/tests/hooks/test_export_hook.py +++ b/tests/hooks/test_export_hook.py @@ -64,7 +64,7 @@ def _build_model(): @MODELS.register_module() class TestExportModel(nn.Module): - def __init__(self): + def __init__(self, pretrained=False): super().__init__() self.linear = nn.Linear(2, 1) diff --git a/tests/tools/test_classification_train.py b/tests/tools/test_classification_train.py index 2d94827e..e04739e0 100644 --- a/tests/tools/test_classification_train.py +++ b/tests/tools/test_classification_train.py @@ -6,10 +6,10 @@ import tempfile import unittest -from mmcv import Config from tests.ut_config import SMALL_IMAGENET_RAW_LOCAL from easycv.file import io +from easycv.utils.config_tools import mmcv_config_fromfile, pai_config_fromfile from easycv.utils.test_util import run_in_subprocess sys.path.append(os.path.dirname(os.path.realpath(__file__))) @@ -51,6 +51,21 @@ SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt', 'model.train_preprocess': ['randomErasing', 'mixUp'] } +}, { + 'config_file': + 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py', + 'cfg_options': { + **_COMMON_OPTIONS, 'data_train_root': + SMALL_IMAGENET_DATA_ROOT + 'train/', + 'data_train_list': + SMALL_IMAGENET_DATA_ROOT + 'meta/train_labeled_200.txt', + 'data_test_root': SMALL_IMAGENET_DATA_ROOT + 'validation/', + 'data_test_list': + SMALL_IMAGENET_DATA_ROOT + 'meta/val_labeled_100.txt', + 'image_resize2': [224, 224], + 'save_epochs': 1, + 'eval_epochs': 1 + } }] @@ -62,17 +77,22 @@ def setUp(self): def tearDown(self): super().tearDown() - def _base_train(self, train_cfgs): + def _base_train(self, train_cfgs, adapt_pai=False): cfg_file = train_cfgs.pop('config_file') cfg_options = train_cfgs.pop('cfg_options', None) work_dir = train_cfgs.pop('work_dir', None) if not work_dir: work_dir = tempfile.TemporaryDirectory().name - cfg = Config.fromfile(cfg_file) - if cfg_options is not None: - cfg.merge_from_dict(cfg_options) + if adapt_pai: + cfg = pai_config_fromfile(cfg_file, user_config_params=cfg_options) cfg.eval_pipelines[0].data = cfg.data.val + else: + cfg = mmcv_config_fromfile(cfg_file) + if cfg_options is not None: + cfg.merge_from_dict(cfg_options) + cfg.eval_pipelines[0].data = cfg.data.val + tmp_cfg_file = tempfile.NamedTemporaryFile(suffix='.py').name cfg.dump(tmp_cfg_file) @@ -100,6 +120,11 @@ def test_classification_mixup(self): self._base_train(train_cfgs) + def test_classification_pai(self): + train_cfgs = copy.deepcopy(TRAIN_CONFIGS[2]) + + self._base_train(train_cfgs, adapt_pai=True) + if __name__ == '__main__': unittest.main() diff --git a/tests/tools/test_mae_train.py b/tests/tools/test_mae_train.py index ffdafb60..c3e1b956 100644 --- a/tests/tools/test_mae_train.py +++ b/tests/tools/test_mae_train.py @@ -76,7 +76,6 @@ def _base_train(self, train_cfgs): if not work_dir: work_dir = tempfile.TemporaryDirectory().name - # cfg = Config.fromfile(cfg_file) cfg = mmcv_config_fromfile(cfg_file) if cfg_options is not None: cfg.merge_from_dict(cfg_options) diff --git a/tests/ut_config.py b/tests/ut_config.py index d5d76d90..340d6bd5 100644 --- a/tests/ut_config.py +++ b/tests/ut_config.py @@ -42,6 +42,8 @@ CIFAR100_LOCAL = os.path.join(BASE_LOCAL_PATH, 'data/classification/cifar100') SAMLL_IMAGENET1K_RAW_LOCAL = os.path.join(BASE_LOCAL_PATH, 'datasets/imagenet-1k/imagenet_raw') +CLASS_LIST_TEST = os.path.join(BASE_LOCAL_PATH, + 'data/classification/class_list_test') SMALL_IMAGENET_TFRECORD_LOCAL = os.path.join( BASE_LOCAL_PATH, 'data/classification/small_imagenet_tfrecord/') diff --git a/tools/train.py b/tools/train.py index fab34e87..2b44c7c4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -25,6 +25,7 @@ import torch import torch.distributed as dist from mmcv.runner import init_dist +from mmcv import DictAction from easycv import __version__ from easycv.apis import init_random_seed, set_random_seed, train_model @@ -35,9 +36,9 @@ from easycv.utils.collect_env import collect_env from easycv.utils.logger import get_root_logger from easycv.utils import mmlab_utils -from easycv.utils.config_tools import traverse_replace -from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO, - mmcv_config_fromfile, rebuild_config) +from easycv.utils.config_tools import (traverse_replace, CONFIG_TEMPLATE_ZOO, + mmcv_config_fromfile, + pai_config_fromfile) from easycv.utils.dist_utils import get_device, is_master from easycv.utils.setup_env import setup_multi_processes @@ -93,9 +94,14 @@ def parse_args(): ) parser.add_argument( '--user_config_params', - nargs=argparse.REMAINDER, - default=None, - help='modify config options using the command-line') + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed. Single quote double quote equivalent.') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: @@ -127,12 +133,13 @@ def main(): pass args.config = tpath - cfg = mmcv_config_fromfile(args.config) - if args.user_config_params is not None: - # assert args.model_type is not None, 'model_type must be setted' - # rebuild config by user config params - cfg = rebuild_config(cfg, args.user_config_params) + # build cfg + if args.user_config_params is None: + cfg = mmcv_config_fromfile(args.config) + else: + cfg = pai_config_fromfile(args.config, args.user_config_params, + args.model_type) # set multi-process settings setup_multi_processes(cfg) @@ -275,7 +282,8 @@ def main(): pin_memory=getattr(cfg.data, 'pin_memory', False), replace=getattr(cfg.data, 'sampling_replace', False), seed=cfg.seed, - drop_last=getattr(cfg.data, 'drop_last', False), + # The default should be set to True, because sometimes the last batch is not sampled enough, causing an error in batchnorm + drop_last=getattr(cfg.data, 'drop_last', True), reuse_worker_cache=cfg.data.get('reuse_worker_cache', False), persistent_workers=cfg.data.get('persistent_workers', False), collate_hooks=cfg.data.get('train_collate_hooks', []),