Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add ProcessDict associated to @dict #53

Merged
merged 4 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion cliconfig/processing/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
and :func:`.make_config`.
"""
import ast
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple

from cliconfig.base import Config
from cliconfig.dict_routines import unflatten
from cliconfig.process_routines import (
merge_flat_paths_processing,
merge_flat_processing,
Expand All @@ -19,6 +20,8 @@
from cliconfig.processing.base import Processing
from cliconfig.tag_routines import clean_all_tags, clean_tag, dict_clean_tags, is_tag_in

TypeSplitDict = Dict[str, List[Tuple[str, Any]]]


class ProcessMerge(Processing):
"""Merge dicts just in time with ``@merge_after/_before/_add`` tags.
Expand Down Expand Up @@ -773,6 +776,127 @@ def presave(self, flat_config: Config) -> Config:
return flat_config


class ProcessDict(Processing):
"""Declare a dict instead of a sub-config with ``@dict`` tag.

This is a pre-merge only processing that remove the dict and remove
the rest of the key after the tag. It can be used to declare a dict
where the keys and are not known in advance. New keys are allowed
in each merge and the element are still available using the dot
notation like ``config.subconfig.mydict.something``.
The pre-merge processing is applied before all other processings.
Pre-merge order: -30.0

Examples
--------
.. code-block:: yaml

# default.yaml
param1: 0
param2: 2
sweep@dict: None

# sweep.yaml
sweep@dict:
metric.name: accuracy
metric.goal: max
method: bayes
parameters:
param1:
min: 0
max: 50

The ``swep`` parameter is a dict and not a sub-config.

Warning
-------

* Processings are not applied in the dict keys. In particular,
the tags are not used and not removed.
* The tag ``@dict`` must be added at the key containing
the dict every time you want to modify the dict.
"""

class PseudoDict:
"""Object containing a dict that dodges flattening."""

def __init__(self, dict_: dict):
self.dict = dict_

def __repr__(self) -> str:
"""Representation."""
return f"PseudoDict({self.dict})"

def __str__(self) -> str:
"""Representation as string."""
return f"PseudoDict({self.dict})"

def __init__(self) -> None:
super().__init__()
self.premerge_order = -30.0
self.endbuild_order = 0.0
self.presave_order = -30.0
self.keys_with_dict: Set[str] = set()

def premerge(self, flat_config: Config) -> Config:
"""Pre-merge processing."""
keys = list(flat_config.dict.keys())
splitter: TypeSplitDict = {}
for key in keys:
if is_tag_in(key, "dict", full_key=True):
splitter = self._split_dict_key(splitter, key, flat_config.dict[key])
del flat_config.dict[key]
new_dict = {}
for key, values in splitter.items():
new_dict[key] = self.PseudoDict(unflatten(dict(values)))
self.keys_with_dict.add(clean_all_tags(key))
flat_config.dict.update(new_dict)
return flat_config

def endbuild(self, flat_config: Config) -> Config:
"""End-build processing."""
for key in flat_config.dict:
if isinstance(flat_config.dict[key], self.PseudoDict):
flat_config.dict[key] = flat_config.dict[key].dict
return flat_config

def presave(self, flat_config: Config) -> Config:
"""Pre-save processing."""
keys = list(flat_config.dict.keys())
for key in keys:
if key.startswith(tuple(self.keys_with_dict)):
for key_dict in self.keys_with_dict:
# Add the tag @dict to the key to keep the information
if key.startswith(key_dict):
new_key = key_dict + "@dict" + key[len(key_dict) :]
flat_config.dict[new_key] = flat_config.dict[key]
del flat_config.dict[key]
break
return flat_config

def _split_dict_key(
self,
splitter: TypeSplitDict,
flat_key: str,
value: Any,
) -> TypeSplitDict:
"""Split a key by @dict."""
split_dict = flat_key.split("@dict")
# Handle the case where there is another @dict in the key
split_dict = [split_dict[0]] + ["@dict".join(split_dict[1:])]
# Include the other tags in the key
split_dot = split_dict[1].split(".")
split_dot = [split_dot[0]] + [".".join(split_dot[1:])]

main_key = split_dict[0] + split_dot[0]
dict_key = split_dot[1]
if main_key not in splitter:
splitter[main_key] = [(dict_key, value)]
else:
splitter[main_key].append((dict_key, value))
return splitter


class ProcessCheckTags(Processing):
"""Raise an error if a tag is present in a key after pre-merging processes.

Expand Down Expand Up @@ -836,5 +960,6 @@ def __init__(self) -> None:
ProcessTyping(),
ProcessSelect(),
ProcessDelete(),
ProcessDict(),
ProcessNew(),
]
14 changes: 9 additions & 5 deletions docs/edge_cases.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@ as config attributes (with dots).
In the context of this package, dictionaries are treated as sub-configurations,
which means that modifying or adding keys directly in the additional configs may
not be possible (because only the merge of default configuration allow adding new keys).
If you need to modify or add keys within a dictionary, consider enclosing it in a list.
If you need to have a dictionary object where you can modify the keys, consider
using the `@dict` tag:

For instance:

```yaml
# default.yaml
logging:
metrics: ['train loss', 'val loss']
styles: [{'train loss': 'red', 'val loss': 'blue'}]
metrics: [train loss, val loss]
styles@dict: {train_loss: red, val_loss: blue}
# additional.yaml
logging:
metrics: ['train loss', 'val loss', 'val acc']
styles: [{'train loss': 'red', 'val loss': 'blue', 'val acc': 'cyan'}]
metrics: [train loss, val acc]
styles@dict: {train_loss: red, val_acc: cyan}
```

Like a sub-config, the dictionary can be accessed with the dot notation like this:
`config.logging.styles.val_acc` and will return "cyan".
3 changes: 3 additions & 0 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ The default tags include:
present in the default config(s). It can be used for single parameter or a
sub-config. Disclaimer: it is preferable to have exhaustive default config(s)
instead of abusing this tag for readability and for security concerning typos.
* `@dict`: This tag allows to have a dictionary object instead of a sub-config
where you can modify the keys (see the
[*Edge cases*](https://cliconfig.readthedocs.io/en/latest/edge_cases.html) section)

The tags are applied in a particular order that ensure no conflict between them.

Expand Down
2 changes: 1 addition & 1 deletion tests/configs/integration/test2/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ models:
vit_b16_cfg_path@merge_add: tests/configs/integration/test2/models/vit_b16.yaml

---

data:
dataset_path@type:str: ../data
dataset_cfg_path@merge_add: tests/configs/integration/test2/data.yaml
Expand All @@ -17,3 +16,4 @@ train:
type@optim_type: SGD
lr@type:float: 0.01
momentum@type:float|int: 0.9
sweep_cfg: None
5 changes: 5 additions & 0 deletions tests/configs/integration/test2/exp1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@ models.vit_b16:
in_size: !copy@type:int data.data_size
dim: 512
train.n_epochs: 30
sweep_cfg@dict:
method: grid
models.dim:
- 512
- 1024
5 changes: 5 additions & 0 deletions tests/configs/integration/test2/exp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ data.data_size: 512
data.dataset_path: ../../mydata
data.augmentation: [RandomHorizontalFlip, RandomVerticalFlip]
1@merge_add: !delete tests/configs/integration/test2/exp3.yaml
sweep_cfg@dict:
method: bayes
data.data_size:
- 512
- 1024
4 changes: 4 additions & 0 deletions tests/integration/test_inte_multiple_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def func_optim_type(x: str) -> str:
"momentum": 0,
},
},
"sweep_cfg": {
"method": "bayes",
"data": {"data_size": [512, 1024]},
},
"metadata": {
"exp_details": {
"goal": "Test multiple processings",
Expand Down
48 changes: 47 additions & 1 deletion tests/unit/processing/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pytest_check as check

from cliconfig.base import Config
from cliconfig.dict_routines import flatten
from cliconfig.dict_routines import flatten, unflatten
from cliconfig.process_routines import merge_flat_processing
from cliconfig.processing.builtin import (
DefaultProcessings,
ProcessCheckTags,
ProcessCopy,
ProcessDef,
ProcessDelete,
ProcessDict,
ProcessMerge,
ProcessNew,
ProcessSelect,
Expand Down Expand Up @@ -405,6 +407,50 @@ def test_process_new() -> None:
)


def test_process_dict() -> None:
"""Test ProcessDict."""
processing = ProcessDict()
flat_dict1 = {
"subconfig.mydict@dict.param1": 1,
"subconfig.mydict@dict.param2": 2,
"subconfig.mydict@dict.foo.param3": 3,
}
config1 = Config(flat_dict1, [processing])
config1 = processing.premerge(config1)
config1.dict = flatten(config1.dict)

check.equal(
config1.dict["subconfig.mydict"].dict,
{"param1": 1, "param2": 2, "foo": {"param3": 3}},
)
flat_dict2 = {
"subconfig.mydict@dict.a": 1,
"subconfig.mydict@dict.b": 2,
}
config2 = Config(flat_dict2)
config = merge_flat_processing(config1, config2, allow_new_keys=False)
check.equal(
config.dict["subconfig.mydict"].dict,
{"a": 1, "b": 2},
)
config = processing.endbuild(config)
config.dict = unflatten(config.dict)
check.equal(
config.dict,
{"subconfig": {"mydict": {"a": 1, "b": 2}}},
)
config.dict = flatten(config.dict)
config = processing.presave(config)
check.equal(
config.dict,
{"subconfig.mydict@dict.a": 1, "subconfig.mydict@dict.b": 2},
)

pseudodict = processing.PseudoDict({"a": 1})
check.equal(repr(pseudodict), "PseudoDict({'a': 1})")
check.equal(str(pseudodict), "PseudoDict({'a': 1})")


def test_process_check_tags() -> None:
"""Test ProcessCheckTags."""
processing = ProcessCheckTags()
Expand Down
Loading