[Enhancement] .dev Python files updated to get better performance and syntax #2020

Merged: 11 commits, Sep 14, 2022

10 changes: 4 additions & 6 deletions .dev/benchmark_inference.py
@@ -53,8 +53,7 @@ def parse_args():
         '-s', '--show', action='store_true', help='show results')
     parser.add_argument(
         '-d', '--device', default='cuda:0', help='Device used for inference')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def inference_model(config_name, checkpoint, args, logger=None):
@@ -66,11 +65,10 @@ def inference_model(config_name, checkpoint, args, logger=None):
             0.5, 0.75, 1.0, 1.25, 1.5, 1.75
         ]
         cfg.data.test.pipeline[1].flip = True
+    elif logger is None:
+        print(f'{config_name}: unable to start aug test', flush=True)
     else:
-        if logger is not None:
-            logger.error(f'{config_name}: unable to start aug test')
-        else:
-            print(f'{config_name}: unable to start aug test', flush=True)
+        logger.error(f'{config_name}: unable to start aug test')
 
     model = init_segmentor(cfg, checkpoint, device=args.device)
     # test a single image

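Both hunks are behavior-preserving simplifications: parse_args returns the parsed namespace directly instead of binding a throwaway variable, and the nested if/else inside the else branch flattens into an elif chain. A minimal sketch of the two idioms (function and argument names here are illustrative, not part of the PR):

import argparse
import logging


def parse_args():
    parser = argparse.ArgumentParser(description='inference benchmark (sketch)')
    parser.add_argument('-s', '--show', action='store_true', help='show results')
    parser.add_argument(
        '-d', '--device', default='cuda:0', help='Device used for inference')
    # No throwaway `args` variable: return the parsed namespace directly.
    return parser.parse_args()


def warn_no_aug_test(config_name, logger: logging.Logger = None):
    # Flattened guard chain: one nesting level instead of two.
    if logger is None:
        print(f'{config_name}: unable to start aug test', flush=True)
    else:
        logger.error(f'{config_name}: unable to start aug test')
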
8 changes: 2 additions & 6 deletions .dev/check_urls.py
@@ -18,12 +18,9 @@ def check_url(url):
     Returns:
         int, bool: status code and check flag.
     """
-    flag = True
     r = requests.head(url)
     status_code = r.status_code
-    if status_code == 403 or status_code == 404:
-        flag = False
-
+    flag = status_code not in [403, 404]
     return status_code, flag
 
 
@@ -35,8 +32,7 @@ def parse_args():
         type=str,
         help='Select the model needed to check')
 
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def main():

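`status_code not in [403, 404]` computes the flag in one expression instead of initializing `flag = True` and flipping it in a branch, and the or-chained comparison disappears with it. A runnable sketch of the resulting helper (the URL below is only an example):

import requests


def check_url(url):
    """Return (status_code, flag); flag is False for 403/404."""
    r = requests.head(url)
    status_code = r.status_code
    flag = status_code not in [403, 404]
    return status_code, flag


if __name__ == '__main__':
    print(check_url('https://github.com/open-mmlab/mmsegmentation'))
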
4 changes: 2 additions & 2 deletions .dev/gather_benchmark_evaluation_results.py
@@ -62,8 +62,8 @@
             continue
 
         # Compare between new benchmark results and previous metrics
-        differential_results = dict()
-        new_metrics = dict()
+        differential_results = {}
+        new_metrics = {}
         for record_metric_key in previous_metrics:
             if record_metric_key not in metric['metric']:
                 raise KeyError('record_metric_key not exist, please '

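The recurring dict() -> {} change is where most of the PR's "performance" claim lives: the literal compiles to a single BUILD_MAP bytecode, while dict() is a global-name lookup plus a call. The difference is tiny per call but free to take. A sketch of how to measure it yourself (absolute numbers vary by machine and interpreter):

import timeit

# {} is one bytecode op; dict() looks up the name, then calls it.
n = 10_000_000
print('{}:     ', timeit.timeit('{}', number=n))
print('dict(): ', timeit.timeit('dict()', number=n))
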
6 changes: 3 additions & 3 deletions .dev/gather_benchmark_train_results.py
@@ -72,9 +72,9 @@
             print(f'log file error: {log_json_path}')
             continue
 
-        differential_results = dict()
-        old_results = dict()
-        new_results = dict()
+        differential_results = {}
+        old_results = {}
+        new_results = {}
         for metric_key in model_performance:
             if metric_key in ['mIoU']:
                 metric = round(model_performance[metric_key] * 100, 2)

23 changes: 10 additions & 13 deletions .dev/gather_models.py
@@ -33,7 +33,7 @@ def process_checkpoint(in_file, out_file):
     # The hash code calculation and rename command differ on different system
     # platform.
     sha = calculate_file_sha256(out_file)
-    final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8])
+    final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
     os.rename(out_file, final_file)
 
     # Remove prefix and suffix
@@ -50,25 +50,21 @@ def get_final_iter(config):
 
 
 def get_final_results(log_json_path, iter_num):
-    result_dict = dict()
+    result_dict = {}
     last_iter = 0
     with open(log_json_path, 'r') as f:
-        for line in f.readlines():
+        for line in f:
             log_line = json.loads(line)
             if 'mode' not in log_line.keys():
                 continue
 
             # When evaluation, the 'iter' of new log json is the evaluation
             # steps on single gpu.
-            flag1 = ('aAcc' in log_line) or (log_line['mode'] == 'val')
-            flag2 = (last_iter == iter_num - 50) or (last_iter == iter_num)
+            flag1 = 'aAcc' in log_line or log_line['mode'] == 'val'
+            flag2 = last_iter in [iter_num - 50, iter_num]
             if flag1 and flag2:
                 result_dict.update({
                     key: log_line[key]
                     for key in RESULTS_LUT if key in log_line
                 })
                 return result_dict
 
             last_iter = log_line['iter']
 
 
@@ -123,7 +119,7 @@ def main():
     exp_dir = osp.join(work_dir, config_name)
     # check whether the exps is finished
     final_iter = get_final_iter(used_config)
-    final_model = 'iter_{}.pth'.format(final_iter)
+    final_model = f'iter_{final_iter}.pth'
     model_path = osp.join(exp_dir, final_model)
 
     # skip if the model is still training
@@ -135,7 +131,7 @@
     log_json_paths = glob.glob(osp.join(exp_dir, '*.log.json'))
     log_json_path = log_json_paths[0]
     model_performance = None
-    for idx, _log_json_path in enumerate(log_json_paths):
+    for _log_json_path in log_json_paths:
         model_performance = get_final_results(_log_json_path, final_iter)
         if model_performance is not None:
             log_json_path = _log_json_path
@@ -161,9 +157,10 @@
     model_publish_dir = osp.join(collect_dir, config_name)
 
     publish_model_path = osp.join(model_publish_dir,
-                                  config_name + '_' + model['model_time'])
+                                  f'{config_name}_' + model['model_time'])
 
     trained_model_path = osp.join(work_dir, config_name,
-                                  'iter_{}.pth'.format(model['iters']))
+                                  f'iter_{model["iters"]}.pth')
     if osp.exists(model_publish_dir):
         for file in os.listdir(model_publish_dir):
             if file.endswith('.pth'):

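Three idioms recur in this file: f-strings instead of str.format, iterating the open file object instead of f.readlines() (lazy, no full-file list in memory), and dropping an unused enumerate index. One caveat worth knowing, unchanged by the PR: out_file.rstrip('.pth') strips any trailing run of the characters '.', 'p', 't', 'h', not the literal suffix. A short sketch of the streaming log scan, reusing the hunk's keys ('mode', 'iter'); the function name and return shape are made up for illustration:

import json


def scan_log(log_json_path):
    """Sketch: stream a .log.json file line by line, tracking the last iter."""
    last_iter = 0
    with open(log_json_path, 'r') as f:
        for line in f:  # lazy: one line at a time
            log_line = json.loads(line)
            if 'mode' not in log_line:  # no need for .keys()
                continue
            if log_line['mode'] == 'val':
                return last_iter, log_line
            last_iter = log_line['iter']
    return last_iter, None
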
10 changes: 4 additions & 6 deletions .dev/generate_benchmark_evaluation_script.py
@@ -20,8 +20,7 @@ def parse_args():
         default='.dev/benchmark_evaluation.sh',
         help='path to save model benchmark script')
 
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def process_model_info(model_info, work_dir):
@@ -30,10 +29,9 @@ def process_model_info(model_info, work_dir):
     job_name = fname
     checkpoint = model_info['checkpoint'].strip()
     work_dir = osp.join(work_dir, fname)
-    if not isinstance(model_info['eval'], list):
-        evals = [model_info['eval']]
-    else:
-        evals = model_info['eval']
+    evals = model_info['eval'] if isinstance(model_info['eval'],
+                                             list) else [model_info['eval']]
 
     eval = ' '.join(evals)
     return dict(
         config=config,

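The four-line if/else collapses into a conditional expression that normalizes a scalar-or-list field; the awkward line break in the hunk is just yapf keeping it under the column limit. The pattern in isolation, with a hypothetical helper name:

def as_list(value):
    """Sketch: normalize a scalar-or-list config field to a list."""
    return value if isinstance(value, list) else [value]


assert as_list('mIoU') == ['mIoU']
assert as_list(['mIoU', 'mDice']) == ['mIoU', 'mDice']
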
7 changes: 2 additions & 5 deletions .dev/generate_benchmark_train_script.py
@@ -69,14 +69,11 @@ def main():
     port = args.port
     partition_name = 'PARTITION=$1'
 
-    commands = []
-    commands.append(partition_name)
-    commands.append('\n')
-    commands.append('\n')
+    commands = [partition_name, '\n', '\n']
 
     with open(args.txt_path, 'r') as f:
         model_cfgs = f.readlines()
-        for i, cfg in enumerate(model_cfgs):
+        for cfg in model_cfgs:
             create_train_bash_info(commands, cfg, script_name, '$PARTITION',
                                    port)
             port += 1

8 changes: 2 additions & 6 deletions .dev/log_collector/log_collector.py
@@ -27,15 +27,11 @@
 def parse_args():
     parser = argparse.ArgumentParser(description='extract info from log.json')
     parser.add_argument('config_dir')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def has_keyword(name: str, keywords: list):
-    for a_keyword in keywords:
-        if a_keyword in name:
-            return True
-    return False
+    return any(a_keyword in name for a_keyword in keywords)
 
 
 def main():

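any() over a generator expression keeps the early exit of the explicit loop (it stops at the first hit) and always returns a real bool; the same spirit drives the list literal that replaced four append calls in the previous file. Sketch, with example file names made up for the asserts:

def has_keyword(name: str, keywords: list) -> bool:
    # Short-circuits at the first keyword found in `name`.
    return any(a_keyword in name for a_keyword in keywords)


assert has_keyword('iter_160000.log.json', ['160000', '80000'])
assert not has_keyword('iter_80000.log.json', ['160000'])
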
3 changes: 1 addition & 2 deletions .dev/upload_modelzoo.py
@@ -19,8 +19,7 @@ def parse_args():
         type=str,
         default='mmsegmentation/v0.5',
         help='destination folder')
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 def main():

6 changes: 5 additions & 1 deletion docs/en/tutorials/config.md
@@ -221,9 +221,13 @@ log_config = dict(  # config to register logger hook
     hooks=[
         dict(type='TextLoggerHook', by_epoch=False),
         dict(type='TensorboardLoggerHook', by_epoch=False),
-        dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}),  # The Wandb logger is also supported. It requires `wandb` to be installed.
+        dict(type='MMSegWandbHook', by_epoch=False,  # The Wandb logger is also supported. It requires `wandb` to be installed.
+             init_kwargs={'entity': "OpenMMLab",  # The entity used to log on Wandb
+                          'project': "MMSeg",  # Project name in WandB
+                          'config': cfg_dict}),  # Check https://docs.wandb.ai/ref/python/init for more init arguments.
+        # MMSegWandbHook is the mmseg implementation of WandbLoggerHook. ClearMLLoggerHook, DvcliveLoggerHook, MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook and SegmindLoggerHook are also supported, based on the MMCV implementation.
     ])
 
 dist_params = dict(backend='nccl')  # Parameters to setup distributed training, the port can also be set.
 log_level = 'INFO'  # The level of logging.
 load_from = None  # load models as a pre-trained model from a given path. This will not resume training.

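The reformatted hook config replaces the previously undefined `entity`/`project` placeholders with concrete strings and documents each key; "OpenMMLab" and "MMSeg" are still just examples to swap for your own entity and project. Under the hood the hook forwards `init_kwargs` to `wandb.init`; a hedged sketch of the equivalent direct call (assuming `wandb` is installed and you are logged in):

import wandb

cfg_dict = {'optimizer': 'SGD', 'lr': 0.01}  # stand-in for the parsed config

# MMSegWandbHook passes init_kwargs through to wandb.init at train start.
wandb.init(
    entity='my-team',    # your W&B user or organization
    project='my-mmseg',  # your W&B project
    config=cfg_dict)     # run configuration to log
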
7 changes: 5 additions & 2 deletions docs/zh_cn/tutorials/config.md
@@ -214,10 +214,13 @@ data = dict(
     ]))
 log_config = dict(  # Config to register the logger hook.
     interval=50,  # Interval for printing logs.
-    hooks=[
+    hooks=[  # Hooks executed during training.
         dict(type='TextLoggerHook', by_epoch=False),
         dict(type='TensorboardLoggerHook', by_epoch=False),
-        dict(type='MMSegWandbHook', by_epoch=False, init_kwargs={'entity': entity, 'project': project, 'config': cfg_dict}),  # The Wandb logger is also supported.
+        dict(type='MMSegWandbHook', by_epoch=False,  # The Wandb logger is also supported; it requires `wandb` to be installed.
+             init_kwargs={'entity': "OpenMMLab",  # The entity used to log in to wandb.
+                          'project': "mmseg",  # Project name in WandB.
+                          'config': cfg_dict}),  # See https://docs.wandb.ai/ref/python/init for more init arguments.
     ])
 
 dist_params = dict(backend='nccl')  # Parameters to set up distributed training; the port can also be set.

71 changes: 30 additions & 41 deletions setup.py
@@ -47,8 +47,7 @@ def parse_line(line):
         if line.startswith('-r '):
             # Allow specifying requirements in other files
             target = line.split(' ')[1]
-            for info in parse_require_file(target):
-                yield info
+            yield from parse_require_file(target)
         else:
             info = {'line': line}
             if line.startswith('-e '):
@@ -58,7 +57,6 @@
                 pat = '(' + '|'.join(['>=', '==', '>']) + ')'
                 parts = re.split(pat, line, maxsplit=1)
                 parts = [p.strip() for p in parts]
-
                 info['package'] = parts[0]
                 if len(parts) > 1:
                     op, rest = parts[1:]
@@ -69,31 +67,30 @@
                                                      rest.split(';'))
                         info['platform_deps'] = platform_deps
                     else:
-                        version = rest  # NOQA
-                        info['version'] = (op, version)
+                        version = rest
+                        info['version'] = op, version
                 yield info
 
     def parse_require_file(fpath):
         with open(fpath, 'r') as f:
             for line in f.readlines():
                 line = line.strip()
                 if line and not line.startswith('#'):
-                    for info in parse_line(line):
-                        yield info
+                    yield from parse_line(line)
 
     def gen_packages_items():
-        if exists(require_fpath):
-            for info in parse_require_file(require_fpath):
-                parts = [info['package']]
-                if with_version and 'version' in info:
-                    parts.extend(info['version'])
-                if not sys.version.startswith('3.4'):
-                    # apparently package_deps are broken in 3.4
-                    platform_deps = info.get('platform_deps')
-                    if platform_deps is not None:
-                        parts.append(';' + platform_deps)
-                item = ''.join(parts)
-                yield item
+        if not exists(require_fpath):
+            return
+        for info in parse_require_file(require_fpath):
+            parts = [info['package']]
+            if with_version and 'version' in info:
+                parts.extend(info['version'])
+            if not sys.version.startswith('3.4'):
+                platform_deps = info.get('platform_deps')
+                if platform_deps is not None:
+                    parts.append(f';{platform_deps}')
+            item = ''.join(parts)
+            yield item
 
     packages = list(gen_packages_items())
     return packages

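`yield from` replaces both two-line delegation loops, and the inverted `if not exists(...): return` guard removes one indentation level from the whole body of `gen_packages_items`. A self-contained sketch of the recursive `-r` handling (the requirements file name is hypothetical, and comment/blank-line filtering is simplified):

def parse_require_file(fpath):
    """Sketch: yield requirement lines, following `-r other.txt` includes."""
    with open(fpath, 'r') as f:
        for line in f:  # lazy iteration
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if line.startswith('-r '):
                # Delegate to the included file's generator.
                yield from parse_require_file(line.split(' ')[1])
            else:
                yield line


# print(list(parse_require_file('requirements.txt')))  # hypothetical file
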
@@ -110,35 +107,28 @@ def add_mim_extension():
     # parse installment mode
     if 'develop' in sys.argv:
         # installed by `pip install -e .`
-        if platform.system() == 'Windows':
-            # set `copy` mode here since symlink fails on Windows.
-            mode = 'copy'
-        else:
-            mode = 'symlink'
-    elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or \
-            platform.system() == 'Windows':
-        # set `copy` mode here since symlink fails on Windows.
+        mode = 'copy' if platform.system() == 'Windows' else 'symlink'
+    elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or platform.system(
+    ) == 'Windows':
         # installed by `pip install .`
         # or create source distribution by `python setup.py sdist`
         # set `copy` mode here since symlink fails with WinError on Windows.
         mode = 'copy'
     else:
         return
 
     filenames = ['tools', 'configs', 'model-index.yml']
     repo_path = osp.dirname(__file__)
     mim_path = osp.join(repo_path, 'mmseg', '.mim')
     os.makedirs(mim_path, exist_ok=True)
 
     for filename in filenames:
         if osp.exists(filename):
             src_path = osp.join(repo_path, filename)
             tar_path = osp.join(mim_path, filename)
 
             if osp.isfile(tar_path) or osp.islink(tar_path):
                 os.remove(tar_path)
             elif osp.isdir(tar_path):
                 shutil.rmtree(tar_path)
 
             if mode == 'symlink':
                 src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
                 try:
@@ -149,20 +139,19 @@ def add_mim_extension():
                     # the error happens, the src file will be copied
                     mode = 'copy'
                     warnings.warn(
-                        f'Failed to create a symbolic link for {src_relpath}, '
-                        f'and it will be copied to {tar_path}')
-                else:
-                    continue
+                        f'Failed to create a symbolic link for {src_relpath},'
+                        f' and it will be copied to {tar_path}')
 
-            if mode == 'copy':
-                if osp.isfile(src_path):
-                    shutil.copyfile(src_path, tar_path)
-                elif osp.isdir(src_path):
-                    shutil.copytree(src_path, tar_path)
-                else:
-                    warnings.warn(f'Cannot copy file {src_path}.')
-            else:
-                continue
+            if mode != 'copy':
+                raise ValueError(f'Invalid mode {mode}')
+            if osp.isfile(src_path):
+                shutil.copyfile(src_path, tar_path)
+            elif osp.isdir(src_path):
+                shutil.copytree(src_path, tar_path)
+            else:
+                warnings.warn(f'Cannot copy file {src_path}.')
 
 
 if __name__ == '__main__':

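The tail of `add_mim_extension` becomes guard-style: bail out with `ValueError` unless the mode has resolved to `copy`, then run the copy logic unconditionally. A condensed, standalone sketch of the symlink-with-copy-fallback pattern (the function name and paths are illustrative; this sketch returns early on a successful symlink):

import os
import os.path as osp
import shutil
import warnings


def link_or_copy(src_path, tar_path, mode='symlink'):
    """Sketch: prefer a relative symlink, fall back to copying."""
    if mode == 'symlink':
        src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
        try:
            os.symlink(src_relpath, tar_path)
            return
        except OSError:
            # e.g. `OSError: [WinError 1314]` without symlink privilege.
            mode = 'copy'
            warnings.warn(f'Failed to create a symbolic link for '
                          f'{src_relpath}, copying to {tar_path}')
    if mode != 'copy':
        raise ValueError(f'Invalid mode {mode}')
    if osp.isfile(src_path):
        shutil.copyfile(src_path, tar_path)
    elif osp.isdir(src_path):
        shutil.copytree(src_path, tar_path)
    else:
        warnings.warn(f'Cannot copy file {src_path}.')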