From 91d9d391a48450047ef495d211805e02816ed474 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 24 Aug 2022 11:32:21 +0200 Subject: [PATCH 1/4] Remove redundant tests. Add confection to requirements.txt and setup.cfg. Adjust config.py. --- requirements.txt | 1 + setup.cfg | 1 + thinc/config.py | 1038 +----------------------------- thinc/tests/test_config.py | 1240 ------------------------------------ 4 files changed, 5 insertions(+), 2275 deletions(-) diff --git a/requirements.txt b/requirements.txt index a1378a3fc..6a04190da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ nbformat>=5.0.4,<5.2.0 # Test to_disk/from_disk against pathlib.Path subclasses pathy>=0.3.5 black>=22.0,<23.0 +confection>=0.0.1,<1.0.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 198823d14..2c147b995 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ install_requires = wasabi>=0.8.1,<1.1.0 srsly>=2.4.0,<3.0.0 catalogue>=2.0.4,<2.1.0 + confection>=0.0.1,<1.0.0 # Third-party dependencies setuptools numpy>=1.15.0 diff --git a/thinc/config.py b/thinc/config.py index 837f91b76..8affebec6 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -1,701 +1,10 @@ -from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, cast -from types import GeneratorType -from dataclasses import dataclass -from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH -from configparser import InterpolationMissingOptionError, InterpolationSyntaxError -from configparser import NoSectionError, NoOptionError, InterpolationDepthError -from configparser import ParsingError -from pathlib import Path -from pydantic import BaseModel, create_model, ValidationError, Extra -from pydantic.main import ModelMetaclass -from pydantic.fields import ModelField -from wasabi import table -import srsly import catalogue -import inspect -import io -import numpy -import copy -import re - +import
confection +from confection import Config, ConfigValidationError from .types import Decorator -# Field used for positional arguments, e.g. [section.*.xyz]. The alias is -# required for the schema (shouldn't clash with user-defined arg names) -ARGS_FIELD = "*" -ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS" -# Aliases for fields that would otherwise shadow pydantic attributes. Can be any -# string, so we're using name + space so it looks the same in error messages etc. -RESERVED_FIELDS = {"validate": "validate\u0020"} -# Internal prefix used to mark section references for custom interpolation -SECTION_PREFIX = "__SECTION__:" -# Values that shouldn't be loaded during interpolation because it'd cause -# even explicit string values to be incorrectly parsed as bools/None etc. -JSON_EXCEPTIONS = ("true", "false", "null") -# Regex to detect whether a value contains a variable -VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") - - -class CustomInterpolation(ExtendedInterpolation): - def before_read(self, parser, section, option, value): - # If we're dealing with a quoted string as the interpolation value, - # make sure we load and unquote it so we don't end up with '"value"' - try: - json_value = srsly.json_loads(value) - if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS: - value = json_value - except Exception: - pass - return super().before_read(parser, section, option, value) - - def before_get(self, parser, section, option, value, defaults): - # Mostly copy-pasted from the built-in configparser implementation. - L = [] - self.interpolate(parser, option, L, value, section, defaults, 1) - return "".join(L) - - def interpolate(self, parser, option, accum, rest, section, map, depth): - # Mostly copy-pasted from the built-in configparser implementation. 
- # We need to overwrite this method so we can add special handling for - # block references :( All values produced here should be strings – - # we need to wait until the whole config is interpreted anyways so - # filling in incomplete values here is pointless. All we need is the - # section reference so we can fetch it later. - rawval = parser.get(section, option, raw=True, fallback=rest) - if depth > MAX_INTERPOLATION_DEPTH: - raise InterpolationDepthError(option, section, rawval) - while rest: - p = rest.find("$") - if p < 0: - accum.append(rest) - return - if p > 0: - accum.append(rest[:p]) - rest = rest[p:] - # p is no longer used - c = rest[1:2] - if c == "$": - accum.append("$") - rest = rest[2:] - elif c == "{": - # We want to treat both ${a:b} and ${a.b} the same - m = self._KEYCRE.match(rest) - if m is None: - err = f"bad interpolation variable reference {rest}" - raise InterpolationSyntaxError(option, section, err) - orig_var = m.group(1) - path = orig_var.replace(":", ".").rsplit(".", 1) - rest = rest[m.end() :] - sect = section - opt = option - try: - if len(path) == 1: - opt = parser.optionxform(path[0]) - if opt in map: - v = map[opt] - else: - # We have block reference, store it as a special key - section_name = parser[parser.optionxform(path[0])]._name - v = self._get_section_name(section_name) - elif len(path) == 2: - sect = path[0] - opt = parser.optionxform(path[1]) - fallback = "__FALLBACK__" - v = parser.get(sect, opt, raw=True, fallback=fallback) - # If a variable doesn't exist, try again and treat the - # reference as a section - if v == fallback: - v = self._get_section_name(parser[f"{sect}.{opt}"]._name) - else: - err = f"More than one ':' found: {rest}" - raise InterpolationSyntaxError(option, section, err) - except (KeyError, NoSectionError, NoOptionError): - raise InterpolationMissingOptionError( - option, section, rawval, orig_var - ) from None - if "$" in v: - new_map = dict(parser.items(sect, raw=True)) - self.interpolate(parser, 
opt, accum, v, sect, new_map, depth + 1) - else: - accum.append(v) - else: - err = "'$' must be followed by '$' or '{', " "found: %r" % (rest,) - raise InterpolationSyntaxError(option, section, err) - - def _get_section_name(self, name: str) -> str: - """Generate the name of a section. Note that we use a quoted string here - so we can use section references within lists and load the list as - JSON. Since section references can't be used within strings, we don't - need the quoted vs. unquoted distinction like we do for variables. - - Examples (assuming section = {"foo": 1}): - - value: ${section.foo} -> value: 1 - - value: "hello ${section.foo}" -> value: "hello 1" - - value: ${section} -> value: {"foo": 1} - - value: "${section}" -> value: {"foo": 1} - - value: "hello ${section}" -> invalid - """ - return f'"{SECTION_PREFIX}{name}"' - - -def get_configparser(interpolate: bool = True): - config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None) - # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719 - config.optionxform = str # type: ignore - return config - - -class Config(dict): - """This class holds the model and training configuration and can load and - save the TOML-style configuration format from/to a string, file or bytes. - The Config class is a subclass of dict and uses Python's ConfigParser - under the hood. - """ - - is_interpolated: bool - - def __init__( - self, - data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None, - *, - is_interpolated: Optional[bool] = None, - section_order: Optional[List[str]] = None, - ) -> None: - """Initialize a new Config object with optional data.""" - dict.__init__(self) - if data is None: - data = {} - if not isinstance(data, (dict, Config, ConfigParser)): - raise ValueError( - f"Can't initialize Config with data. Expected dict, Config or " - f"ConfigParser but got: {type(data)}" - ) - # Whether the config has been interpolated. 
We can use this to check - # whether we need to interpolate again when it's resolved. We assume - # that a config is interpolated by default. - if is_interpolated is not None: - self.is_interpolated = is_interpolated - elif isinstance(data, Config): - self.is_interpolated = data.is_interpolated - else: - self.is_interpolated = True - if section_order is not None: - self.section_order = section_order - elif isinstance(data, Config): - self.section_order = data.section_order - else: - self.section_order = [] - # Update with data - self.update(self._sort(data)) - - def interpolate(self) -> "Config": - """Interpolate a config. Returns a copy of the object.""" - # This is currently the most effective way because we need our custom - # to_str logic to run in order to re-serialize the values so we can - # interpolate them again. ConfigParser.read_dict will just call str() - # on all values, which isn't enough. - return Config().from_str(self.to_str()) - - def interpret_config(self, config: "ConfigParser") -> None: - """Interpret a config, parse nested sections and parse the values - as JSON. Mostly used internally and modifies the config in place. - """ - self._validate_sections(config) - # Sort sections by depth, so that we can iterate breadth-first. This - # allows us to check that we're not expanding an undefined block. - get_depth = lambda item: len(item[0].split(".")) - for section, values in sorted(config.items(), key=get_depth): - if section == "DEFAULT": - # Skip [DEFAULT] section so it doesn't cause validation error - continue - parts = section.split(".") - node = self - for part in parts[:-1]: - if part == "*": - node = node.setdefault(part, {}) - elif part not in node: - err_title = f"Error parsing config section. Perhaps a section name is wrong?" 
- err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) - else: - node = node[part] - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{({part: dict(values)})}" - raise ConfigValidationError(config=err_cfg, errors=err) - # Set the default section - node = node.setdefault(parts[-1], {}) - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{({part: dict(values)})}" - raise ConfigValidationError(config=err_cfg, errors=err) - try: - keys_values = list(values.items()) - except InterpolationMissingOptionError as e: - raise ConfigValidationError(desc=f"{e}") from None - for key, value in keys_values: - config_v = config.get(section, key) - node[key] = self._interpret_value(config_v) - self.replace_section_refs(self) - - def replace_section_refs( - self, config: Union[Dict[str, Any], "Config"], parent: str = "" - ) -> None: - """Replace references to section blocks in the final config.""" - for key, value in config.items(): - key_parent = f"{parent}.{key}".strip(".") - if isinstance(value, dict): - self.replace_section_refs(value, parent=key_parent) - elif isinstance(value, list): - config[key] = [ - self._get_section_ref(v, parent=[parent, key]) for v in value - ] - else: - config[key] = self._get_section_ref(value, parent=[parent, key]) - - def _interpret_value(self, value: Any) -> Any: - """Interpret a single config value.""" - result = try_load_json(value) - # If value is a string and it contains a variable, use original value - # (not interpreted string, which could lead to double quotes: - # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string, - # so we're not keeping lists as strings. 
- # NOTE: This currently can't handle uninterpolated values like [${x.y}]! - if isinstance(result, str) and VARIABLE_RE.search(value): - result = value - if isinstance(result, list): - return [self._interpret_value(v) for v in result] - return result - - def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any: - """Get a single section reference.""" - if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'): - value = try_load_json(value) - if isinstance(value, str) and value.startswith(SECTION_PREFIX): - parts = value.replace(SECTION_PREFIX, "").split(".") - result = self - for item in parts: - try: - result = result[item] - except (KeyError, TypeError): # This should never happen - err_title = "Error parsing reference to config section" - err_msg = f"Section '{'.'.join(parts)}' is not defined" - err = [{"loc": parts, "msg": err_msg}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) from None - return result - elif isinstance(value, str) and SECTION_PREFIX in value: - # String value references a section (either a dict or return - # value of promise). We can't allow this, since variables are - # always interpolated *before* configs are resolved. - err_desc = ( - "Can't reference whole sections or return values of function " - "blocks inside a string or list\n\nYou can change your variable to " - "reference a value instead. Keep in mind that it's not " - "possible to interpolate the return value of a registered " - "function, since variables are interpolated when the config " - "is loaded, and registered functions are resolved afterwards." 
- ) - err = [{"loc": parent, "msg": "uses section variable in string or list"}] - raise ConfigValidationError(errors=err, desc=err_desc) - return value - - def copy(self) -> "Config": - """Deepcopy the config.""" - try: - config = copy.deepcopy(self) - except Exception as e: - raise ValueError(f"Couldn't deep-copy config: {e}") from e - return Config( - config, - is_interpolated=self.is_interpolated, - section_order=self.section_order, - ) - - def merge( - self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False - ) -> "Config": - """Deep merge the config with updates, using current as defaults.""" - defaults = self.copy() - updates = Config(updates).copy() - merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra) - return Config( - merged, - is_interpolated=defaults.is_interpolated and updates.is_interpolated, - section_order=defaults.section_order, - ) - - def _sort( - self, data: Union["Config", "ConfigParser", Dict[str, Any]] - ) -> Dict[str, Any]: - """Sort sections using the currently defined sort order. Sort - sections by index on section order, if available, then alphabetic, and - account for subsections, which should always follow their parent. - """ - sort_map = {section: i for i, section in enumerate(self.section_order)} - sort_key = lambda x: ( - sort_map.get(x[0].split(".")[0], len(sort_map)), - _mask_positional_args(x[0]), - ) - return dict(sorted(data.items(), key=sort_key)) - - def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None: - """Set overrides in the ConfigParser before config is interpreted.""" - err_title = "Error parsing config overrides" - for key, value in overrides.items(): - err_msg = "not a section value that can be overwritten" - err = [{"loc": key.split("."), "msg": err_msg}] - if "." 
not in key: - raise ConfigValidationError(errors=err, title=err_title) - section, option = key.rsplit(".", 1) - # Check for section and accept if option not in config[section] - if section not in config: - raise ConfigValidationError(errors=err, title=err_title) - config.set(section, option, try_dump_json(value, overrides)) - - def _validate_sections(self, config: "ConfigParser") -> None: - # If the config defines top-level properties that are not sections (e.g. - # if config was constructed from dict), those values would be added as - # [DEFAULTS] and included in *every other section*. This is usually not - # what we want and it can lead to very confusing results. - default_section = config.defaults() - if default_section: - err_title = "Found config values without a top-level section" - err_msg = "not part of a section" - err = [{"loc": [k], "msg": err_msg} for k in default_section] - raise ConfigValidationError(errors=err, title=err_title) - - def from_str( - self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {} - ) -> "Config": - """Load the config from a string.""" - config = get_configparser(interpolate=interpolate) - if overrides: - config = get_configparser(interpolate=False) - try: - config.read_string(text) - except ParsingError as e: - desc = f"Make sure the sections and values are formatted correctly.\n\n{e}" - raise ConfigValidationError(desc=desc) from None - config._sections = self._sort(config._sections) - self._set_overrides(config, overrides) - self.clear() - self.interpret_config(config) - if overrides and interpolate: - # do the interpolation. 
Avoids recursion because the new call from_str call will have overrides as empty - self = self.interpolate() - self.is_interpolated = interpolate - return self - - def to_str(self, *, interpolate: bool = True) -> str: - """Write the config to a string.""" - flattened = get_configparser(interpolate=interpolate) - queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)] - for path, node in queue: - section_name = ".".join(path) - is_kwarg = path and path[-1] != "*" - if is_kwarg and not flattened.has_section(section_name): - # Always create sections for non-'*' sections, not only if - # they have leaf entries, as we don't want to expand - # blocks that are undefined - flattened.add_section(section_name) - for key, value in node.items(): - if hasattr(value, "items"): - # Reference to a function with no arguments, serialize - # inline as a dict and don't create new section - if registry.is_promise(value) and len(value) == 1 and is_kwarg: - flattened.set(section_name, key, try_dump_json(value, node)) - else: - queue.append((path + (key,), value)) - else: - flattened.set(section_name, key, try_dump_json(value, node)) - # Order so subsection follow parent (not all sections, then all subs etc.) 
- flattened._sections = self._sort(flattened._sections) - self._validate_sections(flattened) - string_io = io.StringIO() - flattened.write(string_io) - return string_io.getvalue().strip() - - def to_bytes(self, *, interpolate: bool = True) -> bytes: - """Serialize the config to a byte string.""" - return self.to_str(interpolate=interpolate).encode("utf8") - - def from_bytes( - self, - bytes_data: bytes, - *, - interpolate: bool = True, - overrides: Dict[str, Any] = {}, - ) -> "Config": - """Load the config from a byte string.""" - return self.from_str( - bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides - ) - - def to_disk(self, path: Union[str, Path], *, interpolate: bool = True): - """Serialize the config to a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("w", encoding="utf8") as file_: - file_.write(self.to_str(interpolate=interpolate)) - - def from_disk( - self, - path: Union[str, Path], - *, - interpolate: bool = True, - overrides: Dict[str, Any] = {}, - ) -> "Config": - """Load config from a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("r", encoding="utf8") as file_: - text = file_.read() - return self.from_str(text, interpolate=interpolate, overrides=overrides) - - -def _mask_positional_args(name: str) -> List[Optional[str]]: - """Create a section name representation that masks names - of positional arguments to retain their order in sorts.""" - - stable_name = cast(List[Optional[str]], name.split(".")) - - # Remove names of sections that are a positional argument. 
- for i in range(1, len(stable_name)): - if stable_name[i - 1] == "*": - stable_name[i] = None - - return stable_name - - -def try_load_json(value: str) -> Any: - """Load a JSON string if possible, otherwise default to original value.""" - try: - return srsly.json_loads(value) - except Exception: - return value - - -def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str: - """Dump a config value as JSON and output user-friendly error if it fails.""" - # Special case if we have a variable: it's already a string so don't dump - # to preserve ${x:y} vs. "${x:y}" - if isinstance(value, str) and VARIABLE_RE.search(value): - return value - if isinstance(value, str) and value.replace(".", "", 1).isdigit(): - # Work around values that are strings but numbers - value = f'"{value}"' - try: - return srsly.json_dumps(value) - except Exception as e: - err_msg = ( - f"Couldn't serialize config value of type {type(value)}: {e}. Make " - f"sure all values in your config are JSON-serializable. If you want " - f"to include Python objects, use a registered function that returns " - f"the object instead." 
- ) - raise ConfigValidationError(config=data, desc=err_msg) from e - - -def deep_merge_configs( - config: Union[Dict[str, Any], Config], - defaults: Union[Dict[str, Any], Config], - *, - remove_extra: bool = False, -) -> Union[Dict[str, Any], Config]: - """Deep merge two configs.""" - if remove_extra: - # Filter out values in the original config that are not in defaults - keys = list(config.keys()) - for key in keys: - if key not in defaults: - del config[key] - for key, value in defaults.items(): - if isinstance(value, dict): - node = config.setdefault(key, {}) - if not isinstance(node, dict): - continue - value_promises = [k for k in value if k.startswith("@")] - value_promise = value_promises[0] if value_promises else None - node_promises = [k for k in node if k.startswith("@")] if node else [] - node_promise = node_promises[0] if node_promises else None - # We only update the block from defaults if it refers to the same - # registered function - if ( - value_promise - and node_promise - and ( - value_promise in node - and node[value_promise] != value[value_promise] - ) - ): - continue - if node_promise and ( - node_promise not in value or node[node_promise] != value[node_promise] - ): - continue - defaults = deep_merge_configs(node, value, remove_extra=remove_extra) - elif key not in config: - config[key] = value - return config - - -class ConfigValidationError(ValueError): - def __init__( - self, - *, - config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None, - errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(), - title: Optional[str] = "Config validation error", - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: bool = True, - ) -> None: - """Custom error for validating configs. - - config (Union[Config, Dict[str, Dict[str, Any]], str]): The - config the validation error refers to. 
- errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): - A list of errors as dicts with keys "loc" (list of strings - describing the path of the value), "msg" (validation message - to show) and optional "type" (mostly internals). - Same format as produced by pydantic's validation error (e.errors()). - title (str): The error title. - desc (str): Optional error description, displayed below the title. - parent (str): Optional parent to use as prefix for all error locations. - For example, parent "element" will result in "element -> a -> b". - show_config (bool): Whether to print the whole config with the error. - - ATTRIBUTES: - config (Union[Config, Dict[str, Dict[str, Any]], str]): The config. - errors (Iterable[Dict[str, Any]]): The errors. - error_types (Set[str]): All "type" values defined in the errors, if - available. This is most relevant for the pydantic errors that define - types like "type_error.integer". This attribute makes it easy to - check if a config validation error includes errors of a certain - type, e.g. to log additional information or custom help messages. - title (str): The title. - desc (str): The description. - parent (str): The parent. - show_config (bool): Whether to show the config. - text (str): The formatted error text. - """ - self.config = config - self.errors = errors - self.title = title - self.desc = desc - self.parent = parent - self.show_config = show_config - self.error_types = set() - for error in self.errors: - err_type = error.get("type") - if err_type: - self.error_types.add(err_type) - self.text = self._format() - ValueError.__init__(self, self.text) - - @classmethod - def from_error( - cls, - err: "ConfigValidationError", - title: Optional[str] = None, - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: Optional[bool] = None, - ) -> "ConfigValidationError": - """Create a new ConfigValidationError based on an existing error, e.g. - to re-raise it with different settings. 
If no overrides are provided, - the values from the original error are used. - - err (ConfigValidationError): The original error. - title (str): Overwrite error title. - desc (str): Overwrite error description. - parent (str): Overwrite error parent. - show_config (bool): Overwrite whether to show config. - RETURNS (ConfigValidationError): The new error. - """ - return cls( - config=err.config, - errors=err.errors, - title=title if title is not None else err.title, - desc=desc if desc is not None else err.desc, - parent=parent if parent is not None else err.parent, - show_config=show_config if show_config is not None else err.show_config, - ) - - def _format(self) -> str: - """Format the error message.""" - loc_divider = "->" - data = [] - for error in self.errors: - err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])]) - if self.parent: - err_loc = f"{self.parent} {loc_divider} {err_loc}" - data.append((err_loc, error.get("msg"))) - result = [] - if self.title: - result.append(self.title) - if self.desc: - result.append(self.desc) - if data: - result.append(table(data)) - if self.config and self.show_config: - result.append(f"{self.config}") - return "\n\n" + "\n".join(result) - - -def alias_generator(name: str) -> str: - """Generate field aliases in promise schema.""" - # Underscore fields are not allowed in model, so use alias - if name == ARGS_FIELD_ALIAS: - return ARGS_FIELD - # Auto-alias fields that shadow base model attributes - if name in RESERVED_FIELDS: - return RESERVED_FIELDS[name] - return name - - -def copy_model_field(field: ModelField, type_: Any) -> ModelField: - """Copy a model field and assign a new type, e.g. to accept an Any type - even though the original value is typed differently. 
- """ - return ModelField( - name=field.name, - type_=type_, - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - default_factory=field.default_factory, - required=field.required, - ) - - -class EmptySchema(BaseModel): - class Config: - extra = "allow" - arbitrary_types_allowed = True - - -class _PromiseSchemaConfig: - extra = "forbid" - arbitrary_types_allowed = True - alias_generator = alias_generator - - -@dataclass -class Promise: - registry: str - name: str - args: List[str] - kwargs: Dict[str, Any] - - -class registry(object): +class registry(confection.registry): # fmt: off optimizers: Decorator = catalogue.create("thinc", "optimizers", entry_points=True) schedules: Decorator = catalogue.create("thinc", "schedules", entry_points=True) @@ -716,346 +25,5 @@ def create(cls, registry_name: str, entry_points: bool = False) -> None: ) setattr(cls, registry_name, reg) - @classmethod - def has(cls, registry_name: str, func_name: str) -> bool: - """Check whether a function is available in a registry.""" - if not hasattr(cls, registry_name): - return False - reg = getattr(cls, registry_name) - return func_name in reg - - @classmethod - def get(cls, registry_name: str, func_name: str) -> Callable: - """Get a registered function from a given registry.""" - if not hasattr(cls, registry_name): - raise ValueError(f"Unknown registry: '{registry_name}'") - reg = getattr(cls, registry_name) - func = reg.get(func_name) - if func is None: - raise ValueError(f"Could not find '{func_name}' in '{registry_name}'") - return func - - @classmethod - def resolve( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - validate: bool = True, - ) -> Dict[str, Any]: - resolved, _ = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=True - ) - return resolved - - @classmethod - def fill( - cls, - config: Union[Config, 
Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - validate: bool = True, - ): - _, filled = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=False - ) - return filled - - @classmethod - def _make( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - *, - schema: Type[BaseModel] = EmptySchema, - overrides: Dict[str, Any] = {}, - resolve: bool = True, - validate: bool = True, - ) -> Tuple[Dict[str, Any], Config]: - """Unpack a config dictionary and create two versions of the config: - a resolved version with objects from the registry created recursively, - and a filled version with all references to registry functions left - intact, but filled with all values and defaults based on the type - annotations. If validate=True, the config will be validated against the - type annotations of the registered functions referenced in the config - (if available) and/or the schema (if available). - """ - # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}} - # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0} - if cls.is_promise(config): - err_msg = "The top-level config object can't be a reference to a registered function." 
- raise ConfigValidationError(config=config, errors=[{"msg": err_msg}]) - # If a Config was loaded with interpolate=False, we assume it needs to - # be interpolated first, otherwise we take it at face value - is_interpolated = not isinstance(config, Config) or config.is_interpolated - section_order = config.section_order if isinstance(config, Config) else None - orig_config = config - if not is_interpolated: - config = Config(orig_config).interpolate() - filled, _, resolved = cls._fill( - config, schema, validate=validate, overrides=overrides, resolve=resolve - ) - filled = Config(filled, section_order=section_order) - # Check that overrides didn't include invalid properties not in config - if validate: - cls._validate_overrides(filled, overrides) - # Merge the original config back to preserve variables if we started - # with a config that wasn't interpolated. Here, we prefer variables to - # allow auto-filling a non-interpolated config without destroying - # variable references. - if not is_interpolated: - filled = filled.merge( - Config(orig_config, is_interpolated=False), remove_extra=True - ) - return dict(resolved), filled - - @classmethod - def _fill( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - schema: Type[BaseModel] = EmptySchema, - *, - validate: bool = True, - resolve: bool = True, - parent: str = "", - overrides: Dict[str, Dict[str, Any]] = {}, - ) -> Tuple[ - Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] - ]: - """Build three representations of the config: - 1. All promises are preserved (just like config user would provide). - 2. Promises are replaced by their return values. This is the validation - copy and will be parsed by pydantic. It lets us include hacks to - work around problems (e.g. handling of generators). - 3. Final copy with promises replaced by their return values. 
- """ - filled: Dict[str, Any] = {} - validation: Dict[str, Any] = {} - final: Dict[str, Any] = {} - for key, value in config.items(): - # If the field name is reserved, we use its alias for validation - v_key = RESERVED_FIELDS.get(key, key) - key_parent = f"{parent}.{key}".strip(".") - if key_parent in overrides: - value = overrides[key_parent] - config[key] = value - if cls.is_promise(value): - if key in schema.__fields__ and not resolve: - # If we're not resolving the config, make sure that the field - # expecting the promise is typed Any so it doesn't fail - # validation if it doesn't receive the function return value - field = schema.__fields__[key] - schema.__fields__[key] = copy_model_field(field, Any) - promise_schema = cls.make_promise_schema(value, resolve=resolve) - filled[key], validation[v_key], final[key] = cls._fill( - value, - promise_schema, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - reg_name, func_name = cls.get_constructor(final[key]) - args, kwargs = cls.parse_args(final[key]) - if resolve: - # Call the function and populate the field value. We can't - # just create an instance of the type here, since this - # wouldn't work for generics / more complex custom types - getter = cls.get(reg_name, func_name) - # We don't want to try/except this and raise our own error - # here, because we want the traceback if the function fails. - getter_result = getter(*args, **kwargs) - else: - # We're not resolving and calling the function, so replace - # the getter_result with a Promise class - getter_result = Promise( - registry=reg_name, name=func_name, args=args, kwargs=kwargs - ) - validation[v_key] = getter_result - final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. 
- validation[v_key] = [] - elif hasattr(value, "items"): - field_type = EmptySchema - if key in schema.__fields__: - field = schema.__fields__[key] - field_type = field.type_ - if not isinstance(field.type_, ModelMetaclass): - # If we don't have a pydantic schema and just a type - field_type = EmptySchema - filled[key], validation[v_key], final[key] = cls._fill( - value, - field_type, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - if key == ARGS_FIELD and isinstance(validation[v_key], dict): - # If the value of variable positional args is a dict (e.g. - # created via config blocks), only use its values - validation[v_key] = list(validation[v_key].values()) - final[key] = list(final[key].values()) - else: - filled[key] = value - # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) - final[key] = value - # Now that we've filled in all of the promises, update with defaults - # from schema, and validate if validation is enabled - exclude = [] - if validate: - try: - result = schema.parse_obj(validation) - except ValidationError as e: - raise ConfigValidationError( - config=config, errors=e.errors(), parent=parent - ) from None - else: - # Same as parse_obj, but without validation - result = schema.construct(**validation) - # If our schema doesn't allow extra values, we need to filter them - # manually because .construct doesn't parse anything - if schema.Config.extra in (Extra.forbid, Extra.ignore): - fields = schema.__fields__.keys() - exclude = [k for k in result.__fields_set__ if k not in fields] - exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - validation.update(result.dict(exclude=exclude_validation)) - filled, final = cls._update_from_parsed(validation, filled, final) - if exclude: - filled = {k: v for k, v in filled.items() if k not in exclude} - validation = {k: v for k, v in validation.items() if k not 
in exclude} - final = {k: v for k, v in final.items() if k not in exclude} - return filled, validation, final - - @classmethod - def _update_from_parsed( - cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any] - ): - """Update the final result with the parsed config like converted - values recursively. - """ - for key, value in validation.items(): - if key in RESERVED_FIELDS.values(): - continue # skip aliases for reserved fields - if key not in filled: - filled[key] = value - if key not in final: - final[key] = value - if isinstance(value, dict): - filled[key], final[key] = cls._update_from_parsed( - value, filled[key], final[key] - ) - # Update final config with parsed value if they're not equal (in - # value and in type) but not if it's a generator because we had to - # replace that to validate it correctly - elif key == ARGS_FIELD: - continue # don't substitute if list of positional args - elif isinstance(value, numpy.ndarray): # check numpy first, just in case - final[key] = value - elif ( - value != final[key] or not isinstance(type(value), type(final[key])) - ) and not isinstance(final[key], GeneratorType): - final[key] = value - return filled, final - - @classmethod - def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]): - """Validate overrides against a filled config to make sure there are - no references to properties that don't exist and weren't used.""" - error_msg = "Invalid override: config value doesn't exist" - errors = [] - for override_key in overrides.keys(): - if not cls._is_in_config(override_key, filled): - errors.append({"msg": error_msg, "loc": [override_key]}) - if errors: - raise ConfigValidationError(config=filled, errors=errors) - - @classmethod - def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]): - """Check whether a nested config property like "section.subsection.key" - is in a given config.""" - tree = prop.split(".") - obj = dict(config) - while tree: - key = 
tree.pop(0) - if isinstance(obj, dict) and key in obj: - obj = obj[key] - else: - return False - return True - - @classmethod - def is_promise(cls, obj: Any) -> bool: - """Check whether an object is a "promise", i.e. contains a reference - to a registered function (via a key starting with `"@"`. - """ - if not hasattr(obj, "keys"): - return False - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys): - return True - return False - - @classmethod - def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]: - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys) != 1: - err_msg = f"A block can only contain one function registry reference. Got: {id_keys}" - raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}]) - else: - key = id_keys[0] - value = obj[key] - return (key[1:], value) - - @classmethod - def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]: - args = [] - kwargs = {} - for key, value in obj.items(): - if not key.startswith("@"): - if key == ARGS_FIELD: - args = value - elif key in RESERVED_FIELDS.values(): - continue - else: - kwargs[key] = value - return args, kwargs - - @classmethod - def make_promise_schema( - cls, obj: Dict[str, Any], *, resolve: bool = True - ) -> Type[BaseModel]: - """Create a schema for a promise dict (referencing a registry function) - by inspecting the function signature. 
- """ - reg_name, func_name = cls.get_constructor(obj) - if not resolve and not cls.has(reg_name, func_name): - return EmptySchema - func = cls.get(reg_name, func_name) - # Read the argument annotations and defaults from the function signature - id_keys = [k for k in obj.keys() if k.startswith("@")] - sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)} - for param in inspect.signature(func).parameters.values(): - # If no annotation is specified assume it's anything - annotation = param.annotation if param.annotation != param.empty else Any - # If no default value is specified assume that it's required - default = param.default if param.default != param.empty else ... - # Handle spread arguments and use their annotation as Sequence[whatever] - if param.kind == param.VAR_POSITIONAL: - spread_annot = Sequence[annotation] # type: ignore - sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default) - else: - name = RESERVED_FIELDS.get(param.name, param.name) - sig_args[name] = (annotation, default) - sig_args["__config__"] = _PromiseSchemaConfig - return create_model("ArgModel", **sig_args) - __all__ = ["Config", "registry", "ConfigValidationError"] diff --git a/thinc/tests/test_config.py b/thinc/tests/test_config.py index ddd05ca96..fcba87f4b 100644 --- a/thinc/tests/test_config.py +++ b/thinc/tests/test_config.py @@ -135,360 +135,6 @@ def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str: worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False} -def test_validate_simple_config(): - simple_config = {"hello": 1, "world": 2} - f, _, v = my_registry._fill(simple_config, HelloIntsSchema) - assert f == simple_config - assert v == simple_config - - -def test_invalidate_simple_config(): - invalid_config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(invalid_config, HelloIntsSchema) - error = exc_info.value - assert len(error.errors) == 1 - assert "type_error.integer" in error.error_types 
- - -def test_invalidate_extra_args(): - invalid_config = {"hello": 1, "world": 2, "extra": 3} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, HelloIntsSchema) - - -def test_fill_defaults_simple_config(): - valid_config = {"required": 1} - filled, _, v = my_registry._fill(valid_config, DefaultsSchema) - assert filled["required"] == 1 - assert filled["optional"] == "default value" - invalid_config = {"optional": "some value"} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, DefaultsSchema) - - -def test_fill_recursive_config(): - valid_config = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}} - filled, _, validation = my_registry._fill(valid_config, ComplexSchema) - assert filled["outer_req"] == 1 - assert filled["outer_opt"] == "default value" - assert filled["level2_req"]["hello"] == 4 - assert filled["level2_req"]["world"] == 7 - assert filled["level2_opt"]["required"] == 1 - assert filled["level2_opt"]["optional"] == "default value" - - -def test_is_promise(): - assert my_registry.is_promise(good_catsie) - assert not my_registry.is_promise({"hello": "world"}) - assert not my_registry.is_promise(1) - invalid = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - assert my_registry.is_promise(invalid) - - -def test_get_constructor(): - my_registry.get_constructor(good_catsie) == ("cats", "catsie.v1") - - -def test_parse_args(): - args, kwargs = my_registry.parse_args(bad_catsie) - assert args == [] - assert kwargs == {"evil": True, "cute": True} - - -def test_make_promise_schema(): - schema = my_registry.make_promise_schema(good_catsie) - assert "evil" in schema.__fields__ - assert "cute" in schema.__fields__ - - -def test_validate_promise(): - config = {"required": 1, "optional": good_catsie} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled == config - assert validated == {"required": 1, "optional": "meow"} - - -def test_fill_validate_promise(): - 
config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled["optional"]["cute"] is True - - -def test_fill_invalidate_promise(): - config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - with pytest.raises(ConfigValidationError): - my_registry._fill(config, HelloIntsSchema) - config["optional"]["whiskers"] = True - with pytest.raises(ConfigValidationError): - my_registry._fill(config, DefaultsSchema) - - -def test_create_registry(): - with pytest.raises(ValueError): - my_registry.create("cats") - my_registry.create("dogs") - assert hasattr(my_registry, "dogs") - assert len(my_registry.dogs.get_all()) == 0 - my_registry.dogs.register("good_boy.v1", func=lambda x: x) - assert len(my_registry.dogs.get_all()) == 1 - with pytest.raises(ValueError): - my_registry.create("dogs") - - -def test_registry_methods(): - with pytest.raises(ValueError): - my_registry.get("dfkoofkds", "catsie.v1") - my_registry.cats.register("catsie.v123")(None) - with pytest.raises(ValueError): - my_registry.get("cats", "catsie.v123") - - -def test_resolve_no_schema(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - result = my_registry.resolve({"cfg": config})["cfg"] - assert result["one"] == 1 - assert result["two"] == {"three": "scratch!"} - with pytest.raises(ConfigValidationError): - config = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}} - my_registry.resolve(config) - - -def test_resolve_schema(): - class TestBaseSubSchema(BaseModel): - three: str - - class TestBaseSchema(BaseModel): - one: PositiveInt - two: TestBaseSubSchema - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": 
True}}} - with pytest.raises(ConfigValidationError): - # "one" is not a positive int - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}} - with pytest.raises(ConfigValidationError): - # "three" is required in subschema - my_registry.resolve({"cfg": config}, schema=TestSchema) - - -def test_resolve_schema_coerced(): - class TestBaseSchema(BaseModel): - test1: str - test2: bool - test3: float - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"test1": 123, "test2": 1, "test3": 5} - filled = my_registry.fill({"cfg": config}, schema=TestSchema) - result = my_registry.resolve({"cfg": config}, schema=TestSchema) - assert result["cfg"] == {"test1": "123", "test2": True, "test3": 5.0} - # This only affects the resolved config, not the filled config - assert filled["cfg"] == config - - -def test_read_config(): - byte_string = EXAMPLE_CONFIG.encode("utf8") - cfg = Config().from_bytes(byte_string) - - assert cfg["optimizer"]["beta1"] == 0.9 - assert cfg["optimizer"]["learn_rate"]["initial_rate"] == 0.1 - assert cfg["pipeline"]["parser"]["factory"] == "parser" - assert cfg["pipeline"]["parser"]["model"]["tok2vec"]["width"] == 128 - - -def test_optimizer_config(): - cfg = Config().from_str(OPTIMIZER_CFG) - optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] - assert optimizer.b1 == 0.9 - - -def test_config_to_str(): - cfg = Config().from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_creates_intermediate_blocks(): - cfg = Config({"optimizer": {"foo": {"bar": 1}}}) - assert ( - cfg.to_str().strip() - == """ -[optimizer] - -[optimizer.foo] -bar = 1 - """.strip() - ) - - -def test_config_roundtrip_bytes(): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_bytes = cfg.to_bytes() - new_cfg = 
Config().from_bytes(cfg_bytes) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk(): - cfg = Config().from_str(OPTIMIZER_CFG) - with make_tempdir() as path: - cfg_path = path / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_path = pathy_fixture / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_invalid_defaults(): - """Test that an error is raised if a config contains top-level keys without - a section that would otherwise be interpreted as [DEFAULT] (which causes - the values to be included in *all* other sections). - """ - cfg = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}} - with pytest.raises(ConfigValidationError): - Config(cfg).to_str() - config_str = "[DEFAULT]\none = 1" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_validation_custom_types(): - def complex_args( - rate: StrictFloat, - steps: PositiveInt = 10, # type: ignore - log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", - ): - return None - - my_registry.create("complex") - my_registry.complex("complex.v1")(complex_args) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # steps is not a positive int - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"} - with pytest.raises(ConfigValidationError): - # log_level is not a string matching the regex - my_registry.resolve({"config": cfg}) - cfg = {"@complex": 
"complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.resolve(cfg) - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.fill(cfg) - cfg = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - with pytest.raises(ConfigValidationError): - # two constructors - my_registry.resolve({"config": cfg}) - - -def test_validation_no_validate(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}} - result = my_registry.resolve({"cfg": config}, validate=False) - filled = my_registry.fill({"cfg": config}, validate=False) - assert result["cfg"]["one"] == 1 - assert result["cfg"]["two"] == {"three": "scratch!"} - assert filled["cfg"]["two"]["three"]["evil"] == "false" - assert filled["cfg"]["two"]["three"]["cute"] is True - - -def test_validation_fill_defaults(): - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}} - result = my_registry.fill(config, validate=False) - assert len(result["cfg"]["two"]) == 3 - with pytest.raises(ConfigValidationError): - # Required arg "evil" is not defined - my_registry.fill(config) - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}} - # Fill in with new defaults - result = my_registry.fill(config) - assert len(result["cfg"]["two"]) == 4 - assert result["cfg"]["two"]["evil"] is False - assert result["cfg"]["two"]["cute"] is True - assert result["cfg"]["two"]["cute_level"] == 1 - - -def test_make_config_positional_args(): - @my_registry.cats("catsie.v567") - def catsie_567(*args: Optional[str], foo: str = "bar"): - assert args[0] == "^_^" - assert args[1] == "^(*.*)^" - assert foo == "baz" - return args[0] - - args = ["^_^", "^(*.*)^"] - cfg = {"config": {"@cats": "catsie.v567", "foo": "baz", "*": args}} - assert my_registry.resolve(cfg)["config"] == "^_^" - - -def test_make_config_positional_args_complex(): - 
@my_registry.cats("catsie.v890") - def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]): - assert args[0] == 123 - return args[0] - - cfg = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}} - assert my_registry.resolve(cfg)["config"] == 123 - cfg = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}} - with pytest.raises(ConfigValidationError): - # "True" is not a valid boolean or positive int - my_registry.resolve(cfg) - - -def test_positional_args_to_from_string(): - cfg = """[a]\nb = 1\n* = ["foo","bar"]""" - assert Config().from_str(cfg).to_str() == cfg - cfg = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1""" - assert Config().from_str(cfg).to_str() == cfg - - @my_registry.cats("catsie.v666") - def catsie_666(*args, meow=False): - return args - - cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ("foo", "bar")} - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ({"x": 1},)} - - @my_registry.cats("catsie.v777") - def catsie_777(y: int = 1): - return "meow" * y - - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\"""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1""" - assert filled == expected - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3""" - result = my_registry.resolve(Config().from_str(cfg)) - assert result == {"a": ("meowmeowmeow",)} - def 
test_make_config_positional_args_dicts(): cfg = { @@ -511,51 +157,6 @@ def test_make_config_positional_args_dicts(): model.finish_update(resolved["optimizer"]) -def test_validation_generators_iterable(): - @my_registry.optimizers("test_optimizer.v1") - def test_optimizer_v1(rate: float) -> None: - return None - - @my_registry.schedules("test_schedule.v1") - def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]: - while True: - yield some_value - - config = {"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}} - my_registry.resolve(config) - - -def test_validation_unset_type_hints(): - """Test that unset type hints are handled correctly (and treated as Any).""" - - @my_registry.optimizers("test_optimizer.v2") - def test_optimizer_v2(rate, steps: int = 10) -> None: - return None - - config = {"test": {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}} - my_registry.resolve(config) - - -def test_validation_bad_function(): - @my_registry.optimizers("bad.v1") - def bad() -> None: - raise ValueError("This is an error in the function") - return None - - @my_registry.optimizers("good.v1") - def good() -> None: - return None - - # Bad function - config = {"test": {"@optimizers": "bad.v1"}} - with pytest.raises(ValueError): - my_registry.resolve(config) - # Bad function call - config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(config) - - def test_objects_from_config(): config = { "optimizer": { @@ -583,93 +184,6 @@ def decaying(base_rate: float, repeat: int) -> List[float]: assert optimizer.learn_rate == 0.001 -def test_partials_from_config(): - """Test that functions registered with partial applications are handled - correctly (e.g. 
initializers).""" - name = "uniform_init.v1" - cfg = {"test": {"@initializers": name, "lo": -0.2}} - func = my_registry.resolve(cfg)["test"] - assert hasattr(func, "__call__") - # The partial will still have lo as an arg, just with default - assert len(inspect.signature(func).parameters) == 4 - # Make sure returned partial function has correct value set - assert inspect.signature(func).parameters["lo"].default == -0.2 - # Actually call the function and verify - func(NumpyOps(), (2, 3)) - # Make sure validation still works - bad_cfg = {"test": {"@initializers": name, "lo": [0.5]}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - bad_cfg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - - -def test_partials_from_config_nested(): - """Test that partial functions are passed correctly to other registered - functions that consume them (e.g. initializers -> layers).""" - - def test_initializer(a: int, b: int = 1) -> int: - return a * b - - @my_registry.initializers("test_initializer.v1") - def configure_test_initializer(b: int = 1) -> Callable[[int], int]: - return partial(test_initializer, b=b) - - @my_registry.layers("test_layer.v1") - def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]: - return lambda x: x + init(c) - - cfg = { - "@layers": "test_layer.v1", - "c": 5, - "init": {"@initializers": "test_initializer.v1", "b": 10}, - } - func = my_registry.resolve({"test": cfg})["test"] - assert func(1) == 51 - assert func(100) == 150 - - -def test_validate_generator(): - """Test that generator replacement for validation in config doesn't - actually replace the returned value.""" - - @my_registry.schedules("test_schedule.v2") - def test_schedule(): - while True: - yield 10 - - cfg = {"@schedules": "test_schedule.v2"} - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - 
@my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: - return rate - - cfg = { - "@optimizers": "test_optimizer.v2", - "rate": {"@schedules": "test_schedule.v2"}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: - return schedules["rate"] - - cfg = { - "@optimizers": "test_optimizer.v3", - "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: - return schedules[0] - - def test_handle_generic_model_type(): """Test that validation can handle checks against arbitrary generic types in function argument annotations.""" @@ -685,760 +199,6 @@ def my_transform(model: Model[int, int]): assert model.name == "transformed_model" -@pytest.mark.parametrize( - "cfg", - [ - "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3", - "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3", - ], -) -def test_handle_error_duplicate_keys(cfg): - """This would cause very cryptic error when interpreting config. - (TypeError: 'X' object does not support item assignment) - """ - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -@pytest.mark.parametrize( - "cfg,is_valid", - [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)], -) -def test_cant_expand_undefined_block(cfg, is_valid): - """Test that you can't expand a block that hasn't been created yet. This - comes up when you typo a name, and if we allow expansion of undefined blocks, - it's very hard to create good errors for those typos. 
- """ - if is_valid: - Config().from_str(cfg) - else: - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -def test_fill_config_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"]["evil"] is False - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises and that the result is - # interpreted and filled correctly - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.fill(config, overrides=overrides) - assert result["cfg"]["two"] is None - assert result["cfg"]["one"]["@cats"] == "catsie.v1" - assert result["cfg"]["one"]["evil"] is False - assert result["cfg"]["one"]["cute"] is True - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.fill(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.fill(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "two.four": True} - my_registry.fill(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.fill(config, overrides=overrides, validate=True) - - -def test_resolve_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": 
{"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == "meow" - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.resolve(config, overrides=overrides) - assert result["cfg"]["one"] == "meow" - assert result["cfg"]["two"] is None - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.resolve(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.resolve(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "cfg.two.four": True} - my_registry.resolve(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.resolve(config, overrides=overrides, validate=True) - - -@pytest.mark.parametrize( - "prop,expected", - [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)], -) -def test_is_in_config(prop, expected): - config = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}} - assert my_registry._is_in_config(prop, config) is expected - - -def test_resolve_prefilled_values(): - class Language(object): - def __init__(self): - ... 
- - @my_registry.optimizers("prefilled.v1") - def prefilled(nlp: Language, value: int = 10): - return (nlp, value) - - # Passing an instance of Language here via the config is bad, since it - # won't serialize to a string, but we still test for it - config = {"test": {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}} - resolved = my_registry.resolve(config, validate=True) - result = resolved["test"] - assert isinstance(result[0], Language) - assert result[1] == 50 - - -def test_fill_config_dict_return_type(): - """Test that a registered function returning a dict is handled correctly.""" - - @my_registry.cats.register("catsie_with_dict.v1") - def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]: - return {"not_evil": not evil} - - config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10} - result = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"] - assert result["evil"] is False - assert "not_evil" not in result - result = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"] - assert result["not_evil"] is True - - -def test_deepcopy_config(): - config = Config({"a": 1, "b": {"c": 2, "d": 3}}) - copied = config.copy() - # Same values but not same object - assert config == copied - assert config is not copied - # Check for error if value can't be pickled/deepcopied - config = Config({"a": 1, "b": numpy}) - with pytest.raises(ValueError): - config.copy() - - -def test_config_to_str_simple_promises(): - """Test that references to function registries without arguments are - serialized inline as dict.""" - config_str = """[section]\nsubsection = {"@registry":"value"}""" - config = Config().from_str(config_str) - assert config["section"]["subsection"]["@registry"] == "value" - assert config.to_str() == config_str - - -def test_config_from_str_invalid_section(): - config_str = """[a]\nb = null\n\n[a.b]\nc = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - config_str = 
"""[a]\nb = null\n\n[a.b.c]\nd = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_to_str_order(): - """Test that Config.to_str orders the sections.""" - config = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}} - expected = ( - "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5" - ) - config = Config(config) - assert config.to_str() == expected - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation(d): - """Test that config values are interpolated correctly. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == "hello" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\"""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "15!" 
- c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == ["x", "y"] - # Interpolation within the same section - c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\"""" - assert Config().from_str(c_str)["a"]["bar"] == "x" - assert Config().from_str(c_str)["a"]["baz"] == "xy" - - -def test_config_interpolation_lists(): - # Test that lists are preserved correctly - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]""" - config = Config().from_str(c_str, interpolate=False) - assert config["c"]["d"] == ["hello ${a.b}", "world"] - config = config.interpolate() - assert config["c"]["d"] == ["hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == [1, "hello 1", "world"] - config = Config().from_str(c_str, interpolate=False) - # NOTE: This currently doesn't work, because we can't know how to JSON-load - # the uninterpolated list [${a.b}]. - # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"] - # config = config.interpolate() - # assert config["c"]["d"] == [1, "hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == ["hello", {"b": 1}] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]""" - config = Config().from_str(config_str) - assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}] - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation_sections(d): - """Test that config sections are interpolated correctly. 
The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - # Simple block references - c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}""" - config = Config().from_str(c_str) - assert config["b"]["c"] == config["a"] - # References with non-string values - c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]""" - config = Config().from_str(c_str) - assert config["a"]["x"]["y"] == config["a"]["b"] - # Multiple references in the same string - c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\"""" - config = Config().from_str(c_str) - assert config["b"]["z"] == "string/10" - # Non-string references in string (converted to string) - c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\"""" - config = Config().from_str(c_str) - assert config["b"]["y"] == 'result: ["hello", "world"]' - # References to sections referencing sections - c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}""" - config = Config().from_str(c_str) - assert config["b"]["bar"] == config["a"] - assert config["c"]["baz"] == config["b"] - # References to section values referencing other sections - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["b"]["bar"] - # References to sections with subsections - c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["a"] - # Infinite recursion - c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}""" - config = Config().from_str(c_str) - assert config["a"]["b"]["bar"] == config["a"] - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}""" - # We can't reference not-yet 
interpolated subsections - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # Generally invalid references - c_str = f"""[a]\nfoo = ${{b{d}bar}}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # We can't reference sections or promises within strings - c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -def test_config_from_str_overrides(): - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}""" - # Basic value substitution - overrides = {"a.b": 10, "a.c.d": 20} - config = Config().from_str(config_str, overrides=overrides) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["d"] == 20 - assert config["a"]["c"]["e"] == 3 - # Valid values that previously weren't in config - config = Config().from_str(config_str, overrides={"a.c.f": 100}) - assert config["a"]["c"]["d"] == 2 - assert config["a"]["c"]["e"] == 3 - assert config["a"]["c"]["f"] == 100 - # Invalid keys and sections - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f": 10}) - # This currently isn't expected to work, because the dict in f.g is not - # interpreted as a section while the config is still just the configparser - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f.g.x": "z"}) - # With variables (values) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}""" - config = Config().from_str(config_str, overrides={"a.b": 10}) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["e"] == 10 - # With variables (sections) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}""" - config = Config().from_str(config_str, overrides={"a.c.d": 20}) - assert config["a"]["c"]["d"] == 20 - assert config["e"]["f"] == {"d": 20} - - -def test_config_reserved_aliases(): - """Test that the auto-generated pydantic schemas auto-alias reserved - attributes like 
"validate" that would otherwise cause NameError.""" - - @my_registry.cats("catsie.with_alias") - def catsie_with_alias(validate: StrictBool = False): - return validate - - cfg = {"@cats": "catsie.with_alias", "validate": True} - resolved = my_registry.resolve({"test": cfg}) - filled = my_registry.fill({"test": cfg}) - assert resolved["test"] is True - assert filled["test"] == cfg - cfg = {"@cats": "catsie.with_alias", "validate": 20} - with pytest.raises(ConfigValidationError): - my_registry.resolve({"test": cfg}) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_no_interpolation(d): - """Test that interpolation is correctly preserved. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}""" - config = Config().from_str(c_str, interpolate=False) - assert not config.is_interpolated - assert config["c"]["d"] == f"${{a{d}b}}" - assert config["c"]["e"] == f'"hello${{a{d}b}}"' - assert config["c"]["f"] == "${a}" - config2 = Config().from_str(config.to_str(), interpolate=True) - assert config2.is_interpolated - assert config2["c"]["d"] == 1 - assert config2["c"]["e"] == "hello1" - assert config2["c"]["f"] == {"b": 1} - config3 = config.interpolate() - assert config3.is_interpolated - assert config3["c"]["d"] == 1 - assert config3["c"]["e"] == "hello1" - assert config3["c"]["f"] == {"b": 1} - # Bad non-serializable value - cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}} - with pytest.raises(ConfigValidationError): - Config(cfg).interpolate() - - -def test_config_no_interpolation_registry(): - config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}""" - config = Config().from_str(config_str, interpolate=False) - assert 
not config.is_interpolated - assert config["b"]["evil"] == "${a:bad}" - assert config["c"]["d"] == "${b}" - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] == "${a:bad}" - assert filled["b"]["cute"] is True - assert filled["c"]["d"] == "${b}" - interpolated = filled.interpolate() - assert interpolated.is_interpolated - assert interpolated["b"]["evil"] is True - assert interpolated["c"]["d"] == interpolated["b"] - config = Config().from_str(config_str, interpolate=True) - assert config.is_interpolated - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] is True - assert filled["c"]["d"] == filled["b"] - # Resolving a non-interpolated filled config - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - filled = my_registry.fill(config) - assert not filled.is_interpolated - assert filled["c"]["d"] == "${b}" - resolved = my_registry.resolve(filled) - assert resolved["c"]["d"] == "scratch!" 
- - -def test_config_deep_merge(): - config = {"a": "hello", "b": {"c": "d"}} - defaults = {"a": "world", "b": {"c": "e", "f": "g"}} - merged = Config(defaults).merge(config) - assert len(merged) == 2 - assert merged["a"] == "hello" - assert merged["b"] == {"c": "d", "f": "g"} - config = {"a": "hello", "b": {"@test": "x", "foo": 1}} - defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1} - assert merged["c"] == 100 - # Test that leaving out the factory just adds to existing - config = {"a": "hello", "b": {"foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2} - assert merged["c"] == 100 - # Test that switching to a different factory prevents the default from being added - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@bar": "y"}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": "y"} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - - -def test_config_deep_merge_variables(): - config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""" - 
defaults_str = """[a]\nx = 100\n\n[d]\ny = 500""" - config = Config().from_str(config_str, interpolate=False) - defaults = Config().from_str(defaults_str) - merged = defaults.merge(config) - assert merged["a"] == {"b": 1, "c": 2, "x": 100} - assert merged["d"] == {"e": "${a:b}", "y": 500} - assert merged.interpolate()["d"] == {"e": 1, "y": 500} - # With variable in defaults: overwritten by new value - config = Config().from_str("""[a]\nb= 1\nc = 2""") - defaults = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False) - merged = defaults.merge(config) - assert merged["a"]["c"] == 2 - - -def test_config_to_str_roundtrip(): - cfg = {"cfg": {"foo": False}} - config_str = Config(cfg).to_str() - assert config_str == "[cfg]\nfoo = false" - config = Config().from_str(config_str) - assert dict(config) == cfg - cfg = {"cfg": {"foo": "false"}} - config_str = Config(cfg).to_str() - assert config_str == '[cfg]\nfoo = "false"' - config = Config().from_str(config_str) - assert dict(config) == cfg - # Bad non-serializable value - cfg = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}} - config = Config(cfg) - with pytest.raises(ConfigValidationError): - config.to_str() - # Roundtrip with variables: preserve variables correctly (quoted/unquoted) - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\"""" - config = Config().from_str(config_str, interpolate=False) - assert config.to_str() == config_str - - -def test_config_is_interpolated(): - """Test that a config object correctly reports whether it's interpolated.""" - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}""" - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - config = config.merge(Config({"x": {"y": "z"}})) - assert not config.is_interpolated - config = Config(config) - assert not config.is_interpolated - config = config.interpolate() - assert config.is_interpolated - config = 
config.merge(Config().from_str(config_str, interpolate=False)) - assert not config.is_interpolated - - -@pytest.mark.parametrize( - "section_order,expected_str,expected_keys", - [ - # fmt: off - ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]), - (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]), - (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"]) - # fmt: on - ], -) -def test_config_serialize_custom_sort(section_order, expected_str, expected_keys): - cfg = { - "j": {"k": 6}, - "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}}, - "h": {"i": 5}, - } - cfg_str = Config(cfg).to_str() - assert Config(cfg, section_order=section_order).to_str() == expected_str - keys = list(Config(section_order=section_order).from_str(cfg_str).keys()) - assert keys == expected_keys - keys = list(Config(cfg, section_order=section_order).keys()) - assert keys == expected_keys - - -def test_config_custom_sort_preserve(): - """Test that sort order is preserved when merging and copying configs, - or when configs are filled and resolved.""" - cfg = {"x": {}, "y": {}, "z": {}} - section_order = ["y", "z", "x"] - expected = "[y]\n\n[z]\n\n[x]" - config = Config(cfg, section_order=section_order) - assert config.to_str() == expected - config2 = config.copy() - assert config2.to_str() == expected - config3 = config.merge({"a": {}}) - assert config3.to_str() == f"{expected}\n\n[a]" - config4 = Config(config) - assert config4.to_str() == expected - config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2""" - section_order = ["c", "a", "t"] - config5 = Config(section_order=section_order).from_str(config_str) - assert list(config5.keys()) == section_order - filled = my_registry.fill(config5) - assert filled.section_order == section_order - - -def test_config_pickle(): - config = Config({"foo": "bar"}, 
section_order=["foo", "bar", "baz"]) - data = pickle.dumps(config) - config_new = pickle.loads(data) - assert config_new == {"foo": "bar"} - assert config_new.section_order == ["foo", "bar", "baz"] - - -def test_config_fill_extra_fields(): - """Test that filling a config from a schema removes extra fields.""" - - class TestSchemaContent(BaseModel): - a: str - b: int - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestSchemaContent - - config = Config({"cfg": {"a": "1", "b": 2, "c": True}}) - with pytest.raises(ConfigValidationError): - my_registry.fill(config, schema=TestSchema) - filled = my_registry.fill(config, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config2 = config.interpolate() - filled = my_registry.fill(config2, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False) - filled = my_registry.fill(config3, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - - class TestSchemaContent2(BaseModel): - a: str - b: int - - class Config: - extra = "allow" - - class TestSchema2(BaseModel): - cfg: TestSchemaContent2 - - filled = my_registry.fill(config, schema=TestSchema2, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2, "c": True} - - -def test_config_validation_error_custom(): - class Schema(BaseModel): - hello: int - world: int - - config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(config, Schema) - e1 = exc_info.value - assert e1.title == "Config validation error" - assert e1.desc is None - assert not e1.parent - assert e1.show_config is True - assert len(e1.errors) == 1 - assert e1.errors[0]["loc"] == ("world",) - assert e1.errors[0]["msg"] == "value is not a valid integer" - assert e1.errors[0]["type"] == "type_error.integer" - assert e1.error_types == set(["type_error.integer"]) - # 
Create a new error with overrides - title = "Custom error" - desc = "Some error description here" - e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False) - assert e2.errors == e1.errors - assert e2.error_types == e1.error_types - assert e2.title == title - assert e2.desc == desc - assert e2.show_config is False - assert e1.text != e2.text - - -def test_config_parsing_error(): - config_str = "[a]\nb c" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_fill_without_resolve(): - class BaseSchema(BaseModel): - catsie: int - - config = {"catsie": {"@cats": "catsie.v1", "evil": False}} - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["catsie"] == "meow" - assert filled["catsie"]["cute"] is True - with pytest.raises(ConfigValidationError): - my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema) - assert filled2["catsie"]["cute"] is True - resolved = my_registry.resolve(filled2) - assert resolved["catsie"] == "meow" - # With unavailable function - class BaseSchema2(BaseModel): - catsie: Any - other: int = 12 - - config = {"catsie": {"@cats": "dog", "evil": False}} - filled3 = my_registry.fill(config, schema=BaseSchema2) - assert filled3["catsie"] == config["catsie"] - assert filled3["other"] == 12 - - -def test_config_dataclasses(): - @my_registry.cats("catsie.ragged") - def catsie_ragged(arg: Ragged): - return arg - - data = numpy.zeros((20, 4), dtype="f") - lengths = numpy.array([4, 2, 8, 1, 4], dtype="i") - ragged = Ragged(data, lengths) - config = {"cfg": {"@cats": "catsie.ragged", "arg": ragged}} - result = my_registry.resolve(config)["cfg"] - assert isinstance(result, Ragged) - assert list(result._get_starts_ends()) == [0, 4, 6, 14, 15, 19] - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # simple substitution should go fine - [342, "${vars.a}", int], - ["342", "${vars.a}", str], 
- ["everyone", "${vars.a}", str], - ], -) -def test_config_interpolates(greeting, value, expected): - str_cfg = f""" - [project] - my_par = {value} - - [vars] - a = "something" - """ - overrides = {"vars.a": greeting} - cfg = Config().from_str(str_cfg, overrides=overrides) - assert type(cfg["project"]["my_par"]) == expected - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # fmt: off - # simple substitution should go fine - ["hello 342", "${vars.a}", "hello 342"], - ["hello everyone", "${vars.a}", "hello everyone"], - ["hello tout le monde", "${vars.a}", "hello tout le monde"], - ["hello 42", "${vars.a}", "hello 42"], - # substituting an element in a list - ["hello 342", "[1, ${vars.a}, 3]", "hello 342"], - ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"], - ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"], - ["hello 42", "[1, ${vars.a}, 3]", "hello 42"], - # substituting part of a string - [342, "hello ${vars.a}", "hello 342"], - ["everyone", "hello ${vars.a}", "hello everyone"], - ["tout le monde", "hello ${vars.a}", "hello tout le monde"], - pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail), - # substituting part of a implicit string inside a list - [342, "[1, hello ${vars.a}, 3]", "hello 342"], - ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"], - ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"], - pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail), - # substituting part of a explicit string inside a list - [342, "[1, 'hello ${vars.a}', '3']", "hello 342"], - ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"], - ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"], - pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail), - # more complicated example - [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"], - ["everyone", "[{'name':'x','script':['hello 
${vars.a}']}]", "hello everyone"], - ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"], - pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail), - # fmt: on - ], -) -def test_config_overrides(greeting, value, expected): - str_cfg = f""" - [project] - commands = {value} - - [vars] - a = "world" - """ - overrides = {"vars.a": greeting} - assert "${vars.a}" in str_cfg - cfg = Config().from_str(str_cfg, overrides=overrides) - assert expected in str(cfg) - - def test_arg_order_is_preserved(): str_cfg = """ [model] From 1befbd8227fd279b2a134d9d41e749e5736bdfa2 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 24 Aug 2022 12:06:10 +0200 Subject: [PATCH 2/4] Add reference to confection in website/docs/usage-config.md. --- website/docs/usage-config.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/website/docs/usage-config.md b/website/docs/usage-config.md index abb6951e4..74b2aefb7 100644 --- a/website/docs/usage-config.md +++ b/website/docs/usage-config.md @@ -22,6 +22,12 @@ also allows you to link the configuration system to functions in your code using a decorator. Thinc's config system is simpler and emphasizes a different workflow via a subset of Gin's functionality. + +Thinc's config system is wrapping and leveraging +[confection](https://github.com/explosion/confection), which provides this +functionality independently from on Thinc. + + ```ini @@ -654,11 +660,11 @@ resolved = registry.resolve( The main motivation for Thinc's configuration system was to eliminate hidden defaults and ensure that config settings are passed around consistently. This also means that config files should always define **all available settings**. -The [`registry.fill`](/docs/api-config#registry-fill) method also -resolves the config, but it leaves references to registered functions intact and -doesn't replace them with their return values. 
If type annotations and/or a base -schema are available, they will be used to parse the config and fill in any -missing values and defaults to create an up-to-date "master config". +The [`registry.fill`](/docs/api-config#registry-fill) method also resolves the +config, but it leaves references to registered functions intact and doesn't +replace them with their return values. If type annotations and/or a base schema +are available, they will be used to parse the config and fill in any missing +values and defaults to create an up-to-date "master config". Let's say you've updated your schema and scripts to use two additional optional settings. These settings should also be reflected in your config files so they @@ -677,9 +683,9 @@ class TrainingSchema(BaseModel): max_epochs: StrictInt = 100 ``` -Calling [`registry.fill`](/docs/api-config#registry-fill) with your -existing config will produce an updated version of it including the new settings -and their defaults: +Calling [`registry.fill`](/docs/api-config#registry-fill) with your existing +config will produce an updated version of it including the new settings and +their defaults: From 80daeeddd401e96181d090cbc3214211f8f6c387 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 25 Aug 2022 11:20:17 +0200 Subject: [PATCH 3/4] Update confection reference in docs. --- website/docs/usage-config.md | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/website/docs/usage-config.md b/website/docs/usage-config.md index 74b2aefb7..73a1638ac 100644 --- a/website/docs/usage-config.md +++ b/website/docs/usage-config.md @@ -12,21 +12,16 @@ And then once those settings are added, they become hard to remove later. Default values also become hard to change without breaking backwards compatibility. -To solve this problem, Thinc provides a config system that lets you easily -describe **arbitrary trees of objects**. 
The objects can be created via -**function calls you register** using a simple decorator syntax. You can even -version the functions you create, allowing you to make improvements without -breaking backwards compatibility. The most similar config system we're aware of -is [Gin](https://github.com/google/gin-config), which uses a similar syntax, and -also allows you to link the configuration system to functions in your code using -a decorator. Thinc's config system is simpler and emphasizes a different -workflow via a subset of Gin's functionality. - - -Thinc's config system is wrapping and leveraging -[confection](https://github.com/explosion/confection), which provides this -functionality independently from on Thinc. - +To solve this problem, Thinc leverages +[confection](https://github.com/explosion/confection) - a config system that +lets you easily describe **arbitrary trees of objects**. The objects can be +created via **function calls you register** using a simple decorator syntax. You +can even version the functions you create, allowing you to make improvements +without breaking backwards compatibility. The most similar config system we're +aware of is [Gin](https://github.com/google/gin-config), which uses a similar +syntax, and also allows you to link the configuration system to functions in +your code using a decorator. Thinc's config system is simpler and emphasizes a +different workflow via a subset of Gin's functionality. From 6104267e434fa613a42763ce7e3f016010a5549a Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 25 Aug 2022 13:06:45 +0200 Subject: [PATCH 4/4] Extend imports from confection for backwards compatibility.
--- thinc/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thinc/config.py b/thinc/config.py index 8affebec6..8c0e752c5 100644 --- a/thinc/config.py +++ b/thinc/config.py @@ -1,6 +1,6 @@ import catalogue import confection -from confection import Config, ConfigValidationError +from confection import Config, ConfigValidationError, Promise, VARIABLE_RE from .types import Decorator