Skip to content

Commit

Permalink
Support overwriting ConfigDict through flags
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696916439
  • Loading branch information
Conchylicultor authored and The ml_collections Authors committed Nov 15, 2024
1 parent 7a5dd45 commit c55ea08
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
22 changes: 21 additions & 1 deletion ml_collections/config_flags/config_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,32 @@ def flag_type(self):
return 'config_literal'


class _ConfigDictParser(flags.ArgumentParser):
"""Parser for ConfigDict values."""

def parse(self, argument: str) -> config_dict.ConfigDict:
try:
value = ast.literal_eval(argument)
except (SyntaxError, ValueError) as e:
# Otherwise, the flag is a string: `--cfg.value="my_string"`
raise ValueError(
f'Failed to parse {argument!r} as a ConfigDict: {e!r}'
) from None

if not isinstance(value, dict):
raise ValueError(
f'Failed to parse {argument!r} as a ConfigDict: Is not a dict.'
)
return config_dict.ConfigDict(value)


_FIELD_TYPE_TO_PARSER = {
float: flags.FloatParser(),
bool: flags.BooleanParser(),
tuple: tuple_parser.TupleParser(),
int: flags.IntegerParser(),
str: flags.ArgumentParser(),
config_dict.ConfigDict: _ConfigDictParser(),
object: _LiteralParser(),
}

Expand Down Expand Up @@ -830,7 +850,7 @@ def _parse(self, argument):

if parser:
if not isinstance(parser, tuple_parser.TupleParser):
if isinstance(parser, _LiteralParser):
if isinstance(parser, (_LiteralParser, _ConfigDictParser)):
# We do not pass the default to `_ConfigFieldFlag`, otherwise
# `_LiteralParser.parse(default)` is called with `default`,
# which would try to parse string.
Expand Down
15 changes: 15 additions & 0 deletions ml_collections/config_flags/tests/config_overriding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,21 @@ def testLoadingLockedConfigDict(self):
self.assertFalse(values.test_config.is_locked)
self.assertFalse(values.test_config.nested_configdict.is_locked)

def testOverridingNestedConfigDict(self):
"""Tests overriding of ConfigDict fields."""

config_flag = f'--test_config={_CONFIGDICT_CONFIG_FILE}'
values = _parse_flags(
f'./program {config_flag}'
' --test_config.nested_configdict="{\\"a\\": True, \\"b\\": 123}"'
)
self.assertEqual(values.test_config.nested_configdict.a, True)
self.assertEqual(values.test_config.nested_configdict.b, 123)
self.assertEqual(
dict(values.test_config.nested_configdict.items()),
{'a': True, 'b': 123},
)

@parameterized.named_parameters(
('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_DIRECTORY)),
('WithTwoDashes', '--test_config {}'.format(_TEST_DIRECTORY)),
Expand Down

0 comments on commit c55ea08

Please sign in to comment.