Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variables in base files for configs #1083

Merged
merged 11 commits into from
Jun 25, 2021
Merged
54 changes: 54 additions & 0 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import copy
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
Expand Down Expand Up @@ -120,6 +122,51 @@ def _substitute_predefined_vars(filename, temp_config_name):
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)

@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*base\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
innerlee marked this conversation as resolved.
Show resolved Hide resolved
regexp = r'\{\{\s*base\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict

@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
cfg = copy.deepcopy(cfg)

if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]

return cfg

@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
Expand All @@ -140,6 +187,9 @@ def _file2dict(filename, use_predefined_variables=True):
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)

if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
Expand Down Expand Up @@ -184,6 +234,10 @@ def _file2dict(filename, use_predefined_variables=True):
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)

# Subtitute base variables
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)

base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict

Expand Down
6 changes: 6 additions & 0 deletions tests/data/config/t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 = False
item4 = 'test'
item8 = '{{fileBasename}}'
item9 = {{ base.item2 }}
item10 = {{ base.item7.b.c }}
18 changes: 18 additions & 0 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,24 @@ def test_merge_from_multiple_bases():
Config.fromfile(osp.join(data_path, 'config/m.py'))


def test_base_variables():
cfg_file = osp.join(data_path, 'config/t.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 't.py'
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]


def test_merge_recursive_bases():
cfg_file = osp.join(data_path, 'config/f.py')
cfg = Config.fromfile(cfg_file)
Expand Down