Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variables in base files for configs #1083

Merged
merged 11 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,32 @@ _base_ = ['./config_a.py', './config_e.py']
... d='string')
```

#### Reference variables from base

You can reference variables defined in base using the following grammar.

`base.py`

```python
item1 = 'a'
item2 = dict(item3 = 'b')
```

`config_g.py`

```python
_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```

```python
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
```

### ProgressBar

If you want to apply a method to a list of items and track the progress, `track_progress`
Expand Down
60 changes: 60 additions & 0 deletions mmcv/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
Expand Down Expand Up @@ -121,6 +123,57 @@ def _substitute_predefined_vars(filename, temp_config_name):
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)

@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
innerlee marked this conversation as resolved.
Show resolved Hide resolved
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict

@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)

if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split('.'):
new_v = new_v[new_k]
cfg = new_v

return cfg

@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
Expand All @@ -141,6 +194,9 @@ def _file2dict(filename, use_predefined_variables=True):
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)

if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
Expand Down Expand Up @@ -185,6 +241,10 @@ def _file2dict(filename, use_predefined_variables=True):
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)

# Subtitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)

base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict

Expand Down
13 changes: 13 additions & 0 deletions tests/data/config/t.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"_base_": [
"./l1.py",
"./l2.yaml",
"./l3.json",
"./l4.py"
],
"item3": false,
"item4": "test",
"item8": "{{fileBasename}}",
"item9": {{ _base_.item2 }},
"item10": {{ _base_.item7.b.c }}
}
6 changes: 6 additions & 0 deletions tests/data/config/t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 = False
item4 = 'test'
item8 = '{{fileBasename}}'
item9 = {{ _base_.item2 }}
item10 = {{ _base_.item7.b.c }}
6 changes: 6 additions & 0 deletions tests/data/config/t.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 : False
item4 : 'test'
item8 : '{{fileBasename}}'
item9 : {{ _base_.item2 }}
item10 : {{ _base_.item7.b.c }}
26 changes: 26 additions & 0 deletions tests/data/config/u.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"_base_": [
"./t.py"
],
"base": "_base_.item8",
"item11": {{ _base_.item8 }},
"item12": {{ _base_.item9 }},
"item13": {{ _base_.item10 }},
"item14": {{ _base_.item1 }},
"item15": {
"a": {
"b": {{ _base_.item2 }}
},
"b": [
{{ _base_.item3 }}
],
"c": [{{ _base_.item4 }}],
"d": [[
{
"e": {{ _base_.item5.a }}
}
],
{{ _base_.item6 }}],
"e": {{ _base_.item1 }}
}
}
13 changes: 13 additions & 0 deletions tests/data/config/u.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = ['./t.py']
base = '_base_.item8'
item11 = {{ _base_.item8 }}
item12 = {{ _base_.item9 }}
item13 = {{ _base_.item10 }}
item14 = {{ _base_.item1 }}
item15 = dict(
a = dict( b = {{ _base_.item2 }} ),
b = [{{ _base_.item3 }}],
c = [{{ _base_.item4 }}],
d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}],
e = {{ _base_.item1 }}
)
15 changes: 15 additions & 0 deletions tests/data/config/u.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_base_: ["./t.py"]
base: "_base_.item8"
item11: {{ _base_.item8 }}
item12: {{ _base_.item9 }}
item13: {{ _base_.item10 }}
item14: {{ _base_.item1 }}
item15:
a:
b: {{ _base_.item2 }}
b: [{{ _base_.item3 }}]
c: [{{ _base_.item4 }}]
d:
- [e: {{ _base_.item5.a }}]
- {{ _base_.item6 }}
e: {{ _base_.item1 }}
11 changes: 11 additions & 0 deletions tests/data/config/v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = ['./u.py']
item21 = {{ _base_.item11 }}
item22 = item21
item23 = {{ _base_.item10 }}
item24 = item23
item25 = dict(
a = dict( b = item24 ),
b = [item24],
c = [[dict(e = item22)],{{ _base_.item6 }}],
e = item21
)
75 changes: 75 additions & 0 deletions tests/test_utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,81 @@ def test_merge_from_multiple_bases():
Config.fromfile(osp.join(data_path, 'config/m.py'))


def test_base_variables():
for file in ['t.py', 't.json', 't.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == file
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]

# test nested base
for file in ['u.py', 'u.json', 'u.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.base == '_base_.item8'
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 't.py'
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]
assert cfg.item11 == 't.py'
assert cfg.item12 == dict(a=0)
assert cfg.item13 == [3.1, 4.2, 5.3]
assert cfg.item14 == [1, 2]
assert cfg.item15 == dict(
a=dict(b=dict(a=0)),
b=[False],
c=['test'],
d=[[{
'e': 0
}], [{
'a': 0
}, {
'b': 1
}]],
e=[1, 2])

# test reference assignment for py
cfg_file = osp.join(data_path, 'config/v.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.item21 == 't.py'
assert cfg.item22 == 't.py'
assert cfg.item23 == [3.1, 4.2, 5.3]
assert cfg.item24 == [3.1, 4.2, 5.3]
assert cfg.item25 == dict(
a=dict(b=[3.1, 4.2, 5.3]),
b=[[3.1, 4.2, 5.3]],
c=[[{
'e': 't.py'
}], [{
'a': 0
}, {
'b': 1
}]],
e='t.py')


def test_merge_recursive_bases():
cfg_file = osp.join(data_path, 'config/f.py')
cfg = Config.fromfile(cfg_file)
Expand Down