From d5d20f4be6ec4a220b6e094b7ee67e13216316bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Wed, 24 Nov 2021 15:10:59 +0800 Subject: [PATCH] Support to collect the best models (#6560) * Fix mosaic repr typo (#6523) * Include mmflow in readme (#6545) * Include mmflow in readme * Include mmflow in README_zh-CN * Add mmflow url into the document menu in docs/conf.py and docs_zh-CN/conf.py. * Make OHEM work with seesaw loss (#6514) * update * support gather best model Co-authored-by: Kyungmin Lee <30465912+lkm2835@users.noreply.github.com> Co-authored-by: Czm369 <40661020+Czm369@users.noreply.github.com> Co-authored-by: ohwi --- .dev_scripts/gather_models.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/.dev_scripts/gather_models.py b/.dev_scripts/gather_models.py index 9908f068922..7404ae4639b 100644 --- a/.dev_scripts/gather_models.py +++ b/.dev_scripts/gather_models.py @@ -53,6 +53,14 @@ def get_final_epoch(config): return cfg.runner.max_epochs +def get_best_epoch(exp_dir): + best_epoch_full_path = list( + sorted(glob.glob(osp.join(exp_dir, 'best_*.pth'))))[-1] + best_epoch_model_path = best_epoch_full_path.split('/')[-1] + best_epoch = best_epoch_model_path.split('_')[-1].split('.')[0] + return best_epoch_model_path, int(best_epoch) + + def get_real_epoch(config): cfg = mmcv.Config.fromfile('./configs/' + config) epoch = cfg.runner.max_epochs @@ -160,6 +168,10 @@ def parse_args(): help='root path of benchmarked models to be gathered') parser.add_argument( 'out', type=str, help='output path of gathered models to be stored') + parser.add_argument( + '--best', + action='store_true', + help='whether to gather the best model.') args = parser.parse_args() return args @@ -187,10 +199,13 @@ def main(): for used_config in used_configs: exp_dir = osp.join(models_root, used_config) # check whether the exps is finished - final_epoch = get_final_epoch(used_config) - final_model = 'epoch_{}.pth'.format(final_epoch) - model_path = osp.join(exp_dir, final_model) + if args.best is True: + final_model, final_epoch = get_best_epoch(exp_dir) + else: + final_epoch = get_final_epoch(used_config) + final_model = 'epoch_{}.pth'.format(final_epoch) + model_path = osp.join(exp_dir, final_model) # skip if the model is still training if not osp.exists(model_path): continue @@ -221,6 +236,7 @@ def main(): results=model_performance, epochs=final_epoch, model_time=model_time, + final_model=final_model, log_json_path=osp.split(log_json_path)[-1])) # publish model for each checkpoint @@ -234,7 +250,7 @@ def main(): model_name += '_' + model['model_time'] publish_model_path = osp.join(model_publish_dir, model_name) trained_model_path = osp.join(models_root, model['config'], - 'epoch_{}.pth'.format(model['epochs'])) + model['final_model']) # convert model final_model_path = process_checkpoint(trained_model_path, @@ -254,9 +270,9 @@ def main(): config_path = osp.join( 'configs', config_path) if 'configs' not in config_path else config_path - target_cconfig_path = osp.split(config_path)[-1] - shutil.copy(config_path, - osp.join(model_publish_dir, target_cconfig_path)) + target_config_path = osp.split(config_path)[-1] + shutil.copy(config_path, osp.join(model_publish_dir, + target_config_path)) model['model_path'] = final_model_path publish_model_infos.append(model)