diff --git a/.dev/benchmark_inference.py b/.dev/benchmark_inference.py
index 5124811036..3ab681bc38 100644
--- a/.dev/benchmark_inference.py
+++ b/.dev/benchmark_inference.py
@@ -53,8 +53,7 @@ def parse_args():
         '-s', '--show', action='store_true', help='show results')
     parser.add_argument(
         '-d', '--device', default='cuda:0', help='Device used for inference')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def inference_model(config_name, checkpoint, args, logger=None):
@@ -66,11 +65,10 @@ def inference_model(config_name, checkpoint, args, logger=None):
             0.5, 0.75, 1.0, 1.25, 1.5, 1.75
         ]
         cfg.data.test.pipeline[1].flip = True
+    elif logger is None:
+        print(f'{config_name}: unable to start aug test', flush=True)
     else:
-        if logger is not None:
-            logger.error(f'{config_name}: unable to start aug test')
-        else:
-            print(f'{config_name}: unable to start aug test', flush=True)
+        logger.error(f'{config_name}: unable to start aug test')
 
     model = init_segmentor(cfg, checkpoint, device=args.device)
     # test a single image
diff --git a/.dev/check_urls.py b/.dev/check_urls.py
index 42b64745de..c98d0a153e 100644
--- a/.dev/check_urls.py
+++ b/.dev/check_urls.py
@@ -18,12 +18,9 @@ def check_url(url):
     Returns:
         int, bool: status code and check flag.
     """
-    flag = True
    r = requests.head(url)
     status_code = r.status_code
-    if status_code == 403 or status_code == 404:
-        flag = False
-
+    flag = status_code not in [403, 404]
     return status_code, flag
 
 
@@ -35,8 +32,7 @@ def parse_args():
         type=str,
         help='Select the model needed to check')
 
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def main():
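A note on the simplified `check_url()` above: `requests.head()` does not follow redirects by default, so a model-zoo link that 301/302-redirects to a live file is reported with the redirect status rather than the final one. If the check should validate the redirect target as well, a variant like the following would do it (a sketch only; `allow_redirects=True` and the `timeout` are assumptions layered on top of the patched function, not part of this change):

```python
import requests


def check_url(url: str, timeout: float = 10.0):
    """Return (status_code, ok) for a URL, following redirects.

    Sketch: unlike the patched check_url(), this resolves redirects
    before classifying the response.
    """
    r = requests.head(url, allow_redirects=True, timeout=timeout)
    return r.status_code, r.status_code not in [403, 404]
```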
diff --git a/.dev/gather_benchmark_evaluation_results.py b/.dev/gather_benchmark_evaluation_results.py
index 47b557a105..a8bfb4cca9 100644
--- a/.dev/gather_benchmark_evaluation_results.py
+++ b/.dev/gather_benchmark_evaluation_results.py
@@ -62,8 +62,8 @@ def parse_args():
                 continue
 
             # Compare between new benchmark results and previous metrics
-            differential_results = dict()
-            new_metrics = dict()
+            differential_results = {}
+            new_metrics = {}
             for record_metric_key in previous_metrics:
                 if record_metric_key not in metric['metric']:
                     raise KeyError('record_metric_key not exist, please '
diff --git a/.dev/gather_benchmark_train_results.py b/.dev/gather_benchmark_train_results.py
index 8aff2c4228..e729ca205a 100644
--- a/.dev/gather_benchmark_train_results.py
+++ b/.dev/gather_benchmark_train_results.py
@@ -72,9 +72,9 @@ def parse_args():
                 print(f'log file error: {log_json_path}')
                 continue
 
-            differential_results = dict()
-            old_results = dict()
-            new_results = dict()
+            differential_results = {}
+            old_results = {}
+            new_results = {}
             for metric_key in model_performance:
                 if metric_key in ['mIoU']:
                     metric = round(model_performance[metric_key] * 100, 2)
diff --git a/.dev/gather_models.py b/.dev/gather_models.py
index 3eedf6110b..6158623d4c 100644
--- a/.dev/gather_models.py
+++ b/.dev/gather_models.py
@@ -33,7 +33,7 @@ def process_checkpoint(in_file, out_file):
     # The hash code calculation and rename command differ on different system
     # platform.
     sha = calculate_file_sha256(out_file)
-    final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
+    final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
     os.rename(out_file, final_file)
 
     # Remove prefix and suffix
@@ -50,25 +50,23 @@ def get_final_iter(config):
 
 
 def get_final_results(log_json_path, iter_num):
-    result_dict = dict()
+    result_dict = {}
     last_iter = 0
     with open(log_json_path, 'r') as f:
-        for line in f.readlines():
+        for line in f:
             log_line = json.loads(line)
             if 'mode' not in log_line.keys():
                 continue
-
             # When evaluation, the 'iter' of new log json is the evaluation
             # steps on single gpu.
-            flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
-            flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
+            flag1 = 'aAcc' in log_line or log_line['mode'] == 'val'
+            flag2 = last_iter in [iter_num - 50, iter_num]
             if flag1 and flag2:
                 result_dict.update({
                     key: log_line[key]
                     for key in RESULTS_LUT if key in log_line
                 })
                 return result_dict
-
             last_iter = log_line['iter']
 
 
@@ -123,7 +121,7 @@ def main():
         exp_dir = osp.join(work_dir, config_name)
         # check whether the exps is finished
         final_iter = get_final_iter(used_config)
-        final_model = 'iter_{}.pth'.format(final_iter)
+        final_model = f'iter_{final_iter}.pth'
         model_path = osp.join(exp_dir, final_model)
 
         # skip if the model is still training
@@ -135,7 +133,7 @@ def main():
         log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
         log_json_path = log_json_paths[0]
         model_performance = None
-        for idx, _log_json_path in enumerate(log_json_paths):
+        for _log_json_path in log_json_paths:
             model_performance = get_final_results(_log_json_path, final_iter)
             if model_performance is not None:
                 log_json_path = _log_json_path
@@ -161,9 +159,10 @@ def main():
 
         model_publish_dir = osp.join(collect_dir, config_name)
         publish_model_path = osp.join(model_publish_dir,
-                                      config_name + '_' + model['model_time'])
+                                      f'{config_name}_' + model['model_time'])
+
         trained_model_path = osp.join(work_dir, config_name,
-                                      'iter_{}.pth'.format(model['iters']))
+                                      f'iter_{model["iters"]}.pth')
         if osp.exists(model_publish_dir):
             for file in os.listdir(model_publish_dir):
                 if file.endswith('.pth'):
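One pre-existing pitfall the f-string change in `process_checkpoint()` carries over: `str.rstrip('.pth')` strips a trailing *character set*, not a suffix, so any stem ending in `p`, `t`, `h`, or `.` gets over-trimmed (for example, `'net.pth'.rstrip('.pth')` yields `'ne'`). It happens to work here because the checkpoint stems end in digits, but if this line is ever revisited, a suffix-safe spelling might be:

```python
# Sketch: trim the '.pth' suffix rather than a character set.
# (On Python 3.9+, str.removesuffix('.pth') does the same thing.)
def strip_pth_suffix(path: str) -> str:
    return path[:-len('.pth')] if path.endswith('.pth') else path


final_file = strip_pth_suffix(out_file) + f'-{sha[:8]}.pth'
```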
diff --git a/.dev/generate_benchmark_evaluation_script.py b/.dev/generate_benchmark_evaluation_script.py
index d86e94bc8f..fd49f2b5c8 100644
--- a/.dev/generate_benchmark_evaluation_script.py
+++ b/.dev/generate_benchmark_evaluation_script.py
@@ -20,8 +20,7 @@ def parse_args():
         default='.dev/benchmark_evaluation.sh',
         help='path to save model benchmark script')
 
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def process_model_info(model_info, work_dir):
@@ -30,10 +29,9 @@ def process_model_info(model_info, work_dir):
     job_name = fname
     checkpoint = model_info['checkpoint'].strip()
     work_dir = osp.join(work_dir, fname)
-    if not isinstance(model_info['eval'], list):
-        evals = [model_info['eval']]
-    else:
-        evals = model_info['eval']
+    evals = model_info['eval'] if isinstance(model_info['eval'],
+                                             list) else [model_info['eval']]
+
     eval = ' '.join(evals)
     return dict(
         config=config,
diff --git a/.dev/generate_benchmark_train_script.py b/.dev/generate_benchmark_train_script.py
index 6e8a0ae311..32d0a71f54 100644
--- a/.dev/generate_benchmark_train_script.py
+++ b/.dev/generate_benchmark_train_script.py
@@ -69,14 +69,11 @@ def main():
     port = args.port
     partition_name = 'PARTITION=$1'
 
-    commands = []
-    commands.append(partition_name)
-    commands.append('\n')
-    commands.append('\n')
+    commands = [partition_name, '\n', '\n']
 
     with open(args.txt_path, 'r') as f:
         model_cfgs = f.readlines()
-        for i, cfg in enumerate(model_cfgs):
+        for cfg in model_cfgs:
             create_train_bash_info(commands, cfg, script_name, '$PARTITION',
                                    port)
             port += 1
diff --git a/.dev/log_collector/log_collector.py b/.dev/log_collector/log_collector.py
index d0f4080877..cc7b4136c5 100644
--- a/.dev/log_collector/log_collector.py
+++ b/.dev/log_collector/log_collector.py
@@ -27,15 +27,11 @@ def parse_args():
     parser = argparse.ArgumentParser(description='extract info from log.json')
     parser.add_argument('config_dir')
 
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def has_keyword(name: str, keywords: list):
-    for a_keyword in keywords:
-        if a_keyword in name:
-            return True
-    return False
+    return any(a_keyword in name for a_keyword in keywords)
 
 
 def main():
diff --git a/.dev/upload_modelzoo.py b/.dev/upload_modelzoo.py
index 303c80d2e3..a0b94ffbab 100644
--- a/.dev/upload_modelzoo.py
+++ b/.dev/upload_modelzoo.py
@@ -19,8 +19,7 @@ def parse_args():
         type=str,
         default='mmsegmentation/v0.5',
         help='destination folder')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def main():
diff --git a/docs/en/tutorials/config.md b/docs/en/tutorials/config.md
index cbf3777eb0..293f0ac55f 100644
--- a/docs/en/tutorials/config.md
+++ b/docs/en/tutorials/config.md
@@ -221,9 +221,13 @@ log_config = dict(  # config to register logger hook
     hooks=[
         dict(type='TextLoggerHook', by_epoch=False),
         dict(type='TensorboardLoggerHook', by_epoch=False),
-        dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # The Wandb logger is also supported, It requires `wandb` to be installed.
+        dict(type='MMSegWandbHook', by_epoch=False, # The Wandb logger is also supported; it requires `wandb` to be installed.
+             init_kwargs={'entity': "OpenMMLab", # The entity used to log on Wandb
+                          'project': "MMSeg", # Project name in WandB
+                          'config': cfg_dict}), # Check https://docs.wandb.ai/ref/python/init for more init arguments.
         # MMSegWandbHook is mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook are also supported based on MMCV implementation.
     ])
+
 dist_params = dict(backend='nccl')  # Parameters to setup distributed training, the port can also be set.
 log_level = 'INFO'  # The level of logging.
 load_from = None  # load models as a pre-trained model from a given path. This will not resume training.
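For readers of the tutorial change above: the keys in `init_kwargs` are handed to `wandb.init()` at the start of training (MMSegWandbHook builds on MMCV's WandbLoggerHook, which forwards them), so the hook configuration behaves roughly like the direct call below. This is a sketch; `'OpenMMLab'` and `'MMSeg'` are the placeholder entity/project names from the docs, and `cfg_dict` stands in for the parsed training config:

```python
import wandb

cfg_dict = {'optimizer': 'SGD', 'lr': 0.01}  # stand-in for the parsed config

# Roughly what the hook does with `init_kwargs` (sketch; see
# https://docs.wandb.ai/ref/python/init for the full argument list).
wandb.init(
    entity='OpenMMLab',  # the team/user account that owns the run
    project='MMSeg',     # the project the run is grouped under
    config=cfg_dict)     # training config, browsable in the WandB UI
```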
diff --git a/docs/zh_cn/tutorials/config.md b/docs/zh_cn/tutorials/config.md
index 8402b138b3..f4741f1fe6 100644
--- a/docs/zh_cn/tutorials/config.md
+++ b/docs/zh_cn/tutorials/config.md
@@ -214,10 +214,13 @@ data = dict(
         ]))
 log_config = dict(  # 注册日志钩 (register logger hook) 的配置文件。
     interval=50,  # 打印日志的间隔
-    hooks=[
+    hooks=[ # 训练期间执行的钩子
         dict(type='TextLoggerHook', by_epoch=False),
         dict(type='TensorboardLoggerHook', by_epoch=False),
-        dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}), # 同样支持 Wandb 日志
+        dict(type='MMSegWandbHook', by_epoch=False, # 还支持 Wandb 记录器,它需要安装 `wandb`。
+             init_kwargs={'entity': "OpenMMLab", # 用于登录 Wandb 的实体
+                          'project': "MMSeg", # WandB 中的项目名称
+                          'config': cfg_dict}), # 查阅 https://docs.wandb.ai/ref/python/init 以获取更多初始化参数
     ])
 
 dist_params = dict(backend='nccl')  # 用于设置分布式训练的参数,端口也同样可被设置。
diff --git a/setup.py b/setup.py
index ad09e6ce76..7461e76937 100755
--- a/setup.py
+++ b/setup.py
@@ -47,8 +47,7 @@ def parse_line(line):
             if line.startswith('-r '):
                 # Allow specifying requirements in other files
                 target = line.split(' ')[1]
-                for info in parse_require_file(target):
-                    yield info
+                yield from parse_require_file(target)
             else:
                 info = {'line': line}
                 if line.startswith('-e '):
@@ -58,7 +57,6 @@ def parse_line(line):
                     pat = '(' + '|'.join(['>=', '==', '>']) + ')'
                     parts = re.split(pat, line, maxsplit=1)
                     parts = [p.strip() for p in parts]
-
                     info['package'] = parts[0]
                     if len(parts) > 1:
                         op, rest = parts[1:]
@@ -69,8 +67,8 @@ def parse_line(line):
                                                  rest.split(';'))
                             info['platform_deps'] = platform_deps
                         else:
-                            version = rest  # NOQA
-                            info['version'] = (op, version)
+                            version = rest
+                            info['version'] = op, version
                         yield info
 
     def parse_require_file(fpath):
@@ -78,22 +76,21 @@
         with open(fpath, 'r') as f:
             for line in f.readlines():
                 line = line.strip()
                 if line and not line.startswith('#'):
-                    for info in parse_line(line):
-                        yield info
+                    yield from parse_line(line)
 
     def gen_packages_items():
-        if exists(require_fpath):
-            for info in parse_require_file(require_fpath):
-                parts = [info['package']]
-                if with_version and 'version' in info:
-                    parts.extend(info['version'])
-                if not sys.version.startswith('3.4'):
-                    # apparently package_deps are broken in 3.4
-                    platform_deps = info.get('platform_deps')
-                    if platform_deps is not None:
-                        parts.append(';' + platform_deps)
-                item = ''.join(parts)
-                yield item
+        if not exists(require_fpath):
+            return
+        for info in parse_require_file(require_fpath):
+            parts = [info['package']]
+            if with_version and 'version' in info:
+                parts.extend(info['version'])
+            if not sys.version.startswith('3.4'):
+                platform_deps = info.get('platform_deps')
+                if platform_deps is not None:
+                    parts.append(f';{platform_deps}')
+            item = ''.join(parts)
+            yield item
 
     packages = list(gen_packages_items())
     return packages
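To make the `yield from` refactor above concrete, here is roughly what one requirements line turns into. `parse_line` is a closure inside `parse_requirements` in setup.py, so this stand-alone `parse_spec` is an illustrative sketch rather than an importable function:

```python
import re


def parse_spec(line: str) -> dict:
    """Simplified stand-in for setup.py's parse_line (sketch only)."""
    pat = '(' + '|'.join(['>=', '==', '>']) + ')'
    parts = [p.strip() for p in re.split(pat, line, maxsplit=1)]
    info = {'line': line, 'package': parts[0]}
    if len(parts) > 1:
        op, rest = parts[1:]
        # Anything after ';' is an environment marker (platform_deps).
        info['version'] = (op, rest.split(';')[0].strip())
    return info


assert parse_spec('torch>=1.3.0') == {
    'line': 'torch>=1.3.0',
    'package': 'torch',
    'version': ('>=', '1.3.0'),
}
```

`gen_packages_items()` then joins `package` and `version` back into the familiar `torch>=1.3.0` form that `setup()` consumes.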
@@ -110,35 +107,28 @@ def add_mim_extension():
     # parse installment mode
     if 'develop' in sys.argv:
         # installed by `pip install -e .`
-        if platform.system() == 'Windows':
-            # set `copy` mode here since symlink fails on Windows.
-            mode = 'copy'
-        else:
-            mode = 'symlink'
-    elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or \
-            platform.system() == 'Windows':
+        # set `copy` mode here since symlink fails on Windows.
+        mode = 'copy' if platform.system() == 'Windows' else 'symlink'
+    elif ('sdist' in sys.argv or 'bdist_wheel' in sys.argv
+          or platform.system() == 'Windows'):
         # installed by `pip install .`
         # or create source distribution by `python setup.py sdist`
         # set `copy` mode here since symlink fails with WinError on Windows.
         mode = 'copy'
     else:
         return
-
     filenames = ['tools', 'configs', 'model-index.yml']
     repo_path = osp.dirname(__file__)
     mim_path = osp.join(repo_path, 'mmseg', '.mim')
     os.makedirs(mim_path, exist_ok=True)
-
     for filename in filenames:
         if osp.exists(filename):
             src_path = osp.join(repo_path, filename)
             tar_path = osp.join(mim_path, filename)
-
             if osp.isfile(tar_path) or osp.islink(tar_path):
                 os.remove(tar_path)
             elif osp.isdir(tar_path):
                 shutil.rmtree(tar_path)
-
             if mode == 'symlink':
                 src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
                 try:
@@ -149,20 +139,19 @@ def add_mim_extension():
                     # the error happens, the src file will be copied
                     mode = 'copy'
                     warnings.warn(
-                        f'Failed to create a symbolic link for {src_relpath}, '
-                        f'and it will be copied to {tar_path}')
-                else:
-                    continue
+                        f'Failed to create a symbolic link for {src_relpath},'
+                        f' and it will be copied to {tar_path}')
 
-            if mode == 'copy':
-                if osp.isfile(src_path):
-                    shutil.copyfile(src_path, tar_path)
-                elif osp.isdir(src_path):
-                    shutil.copytree(src_path, tar_path)
                 else:
-                    warnings.warn(f'Cannot copy file {src_path}.')
-            else:
+                    continue
+            if mode != 'copy':
                 raise ValueError(f'Invalid mode {mode}')
+            if osp.isfile(src_path):
+                shutil.copyfile(src_path, tar_path)
+            elif osp.isdir(src_path):
+                shutil.copytree(src_path, tar_path)
+            else:
+                warnings.warn(f'Cannot copy file {src_path}.')
 
 
 if __name__ == '__main__':
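The restructured `add_mim_extension()` now reads as: try a symlink, downgrade `mode` to `'copy'` on `OSError` (e.g. `WinError 1314` when symlink privileges are missing on Windows), validate the mode, then copy. Distilled to its core, the pattern looks like the sketch below (a hedged summary of the control flow, not the function itself):

```python
import os
import os.path as osp
import shutil
import warnings


def link_or_copy(src_path: str, tar_path: str) -> None:
    """Symlink src into tar, falling back to a copy on OSError."""
    try:
        os.symlink(osp.relpath(src_path, osp.dirname(tar_path)), tar_path)
        return  # symlink worked, nothing left to do
    except OSError:  # e.g. missing symlink privilege on Windows
        warnings.warn(f'Failed to symlink {src_path}; copying instead.')
    if osp.isdir(src_path):
        shutil.copytree(src_path, tar_path)
    else:
        shutil.copyfile(src_path, tar_path)
```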