diff --git a/ml_collections/config_flags/config_flags.py b/ml_collections/config_flags/config_flags.py index 8569668..553808e 100644 --- a/ml_collections/config_flags/config_flags.py +++ b/ml_collections/config_flags/config_flags.py @@ -83,12 +83,32 @@ def flag_type(self): return 'config_literal' +class _ConfigDictParser(flags.ArgumentParser): + """Parser for ConfigDict values.""" + + def parse(self, argument: str) -> config_dict.ConfigDict: + try: + value = ast.literal_eval(argument) + except (SyntaxError, ValueError) as e: + # Otherwise, the flag is a string: `--cfg.value="my_string"` + raise ValueError( + f'Failed to parse {argument!r} as a ConfigDict: {e!r}' + ) from None + + if not isinstance(value, dict): + raise ValueError( + f'Failed to parse {argument!r} as a ConfigDict: Is not a dict.' + ) + return config_dict.ConfigDict(value) + + _FIELD_TYPE_TO_PARSER = { float: flags.FloatParser(), bool: flags.BooleanParser(), tuple: tuple_parser.TupleParser(), int: flags.IntegerParser(), str: flags.ArgumentParser(), + config_dict.ConfigDict: _ConfigDictParser(), object: _LiteralParser(), } @@ -830,7 +850,7 @@ def _parse(self, argument): if parser: if not isinstance(parser, tuple_parser.TupleParser): - if isinstance(parser, _LiteralParser): + if isinstance(parser, (_LiteralParser, _ConfigDictParser)): # We do not pass the default to `_ConfigFieldFlag`, otherwise # `_LiteralParser.parse(default)` is called with `default`, # which would try to parse string. diff --git a/ml_collections/config_flags/tests/config_overriding_test.py b/ml_collections/config_flags/tests/config_overriding_test.py index 2cc0277..68a1478 100644 --- a/ml_collections/config_flags/tests/config_overriding_test.py +++ b/ml_collections/config_flags/tests/config_overriding_test.py @@ -493,6 +493,21 @@ def testLoadingLockedConfigDict(self): self.assertFalse(values.test_config.is_locked) self.assertFalse(values.test_config.nested_configdict.is_locked) + def testOverridingNestedConfigDict(self): + """Tests overriding of ConfigDict fields.""" + + config_flag = f'--test_config={_CONFIGDICT_CONFIG_FILE}' + values = _parse_flags( + f'./program {config_flag}' + ' --test_config.nested_configdict="{\\"a\\": True, \\"b\\": 123}"' + ) + self.assertEqual(values.test_config.nested_configdict.a, True) + self.assertEqual(values.test_config.nested_configdict.b, 123) + self.assertEqual( + dict(values.test_config.nested_configdict.items()), + {'a': True, 'b': 123}, + ) + @parameterized.named_parameters( ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_DIRECTORY)), ('WithTwoDashes', '--test_config {}'.format(_TEST_DIRECTORY)),