Skip to content

Commit

Permalink
Support overriding registered functions in configs (explosion#12623)
Browse files Browse the repository at this point in the history
Support overriding registered functions in configs. Previously the registry name was parsed as a section name rather than as a registry name.
  • Loading branch information
adrianeboyd committed Jun 28, 2023
1 parent 1f663b7 commit b8d40ca
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 8 deletions.
50 changes: 50 additions & 0 deletions spacy/tests/serialize/test_serialize_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
from spacy.training import Example
from spacy.util import load_config, load_config_from_str
from spacy.util import load_model_from_config, registry

Expand Down Expand Up @@ -415,6 +416,55 @@ def test_config_overrides():
assert nlp.pipe_names == ["tok2vec", "tagger"]


@pytest.mark.filterwarnings("ignore:\\[W036")
def test_config_overrides_registered_functions():
nlp = spacy.blank("en")
nlp.add_pipe("attribute_ruler")
with make_tempdir() as d:
nlp.to_disk(d)
nlp_re1 = spacy.load(
d,
config={
"components": {
"attribute_ruler": {
"scorer": {"@scorers": "spacy.tagger_scorer.v1"}
}
}
},
)
assert (
nlp_re1.config["components"]["attribute_ruler"]["scorer"]["@scorers"]
== "spacy.tagger_scorer.v1"
)

@registry.misc("test_some_other_key")
def misc_some_other_key():
return "some_other_key"

nlp_re2 = spacy.load(
d,
config={
"components": {
"attribute_ruler": {
"scorer": {
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
"spans_key": {"@misc": "test_some_other_key"},
}
}
}
},
)
assert nlp_re2.config["components"]["attribute_ruler"]["scorer"][
"spans_key"
] == {"@misc": "test_some_other_key"}
# run dummy evaluation (will return None scores) in order to test that
# the spans_key value in the nested override is working as intended in
# the config
example = Example.from_dict(nlp_re2.make_doc("a b c"), {})
scores = nlp_re2.evaluate([example])
assert "spans_some_other_key_f" in scores


def test_config_interpolation():
config = Config().from_str(nlp_config_string, interpolate=False)
assert config["corpora"]["train"]["path"] == "${paths.train}"
Expand Down
27 changes: 27 additions & 0 deletions spacy/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def test_minor_version(a1, a2, b1, b2, is_match):
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
),
(
{"attribute_ruler.scorer.@scorers": "spacy.tagger_scorer.v1"},
{"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
),
],
)
def test_dot_to_dict(dot_notation, expected):
Expand All @@ -245,6 +249,29 @@ def test_dot_to_dict(dot_notation, expected):
assert util.dict_to_dot(result) == dot_notation


@pytest.mark.parametrize(
"dot_notation,expected",
[
(
{"token.pos": True, "token._.xyz": True},
{"token": {"pos": True, "_": {"xyz": True}}},
),
(
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
),
(
{"attribute_ruler.scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
{"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
),
],
)
def test_dot_to_dict_overrides(dot_notation, expected):
result = util.dot_to_dict(dot_notation)
assert result == expected
assert util.dict_to_dot(result, for_overrides=True) == dot_notation


def test_set_dot_to_object():
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
with pytest.raises(KeyError):
Expand Down
28 changes: 20 additions & 8 deletions spacy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def load_model_from_path(
if not meta:
meta = get_model_meta(model_path)
config_path = model_path / "config.cfg"
overrides = dict_to_dot(config)
overrides = dict_to_dot(config, for_overrides=True)
config = load_config(config_path, overrides=overrides)
nlp = load_model_from_config(
config,
Expand Down Expand Up @@ -1479,14 +1479,19 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
return result


def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
def dict_to_dot(obj: Dict[str, dict], *, for_overrides: bool = False) -> Dict[str, Any]:
"""Convert dot notation to a dict. For example: {"token": {"pos": True,
"_": {"xyz": True }}} becomes {"token.pos": True, "token._.xyz": True}.
values (Dict[str, dict]): The dict to convert.
obj (Dict[str, dict]): The dict to convert.
for_overrides (bool): Whether to enable special handling for registered
functions in overrides.
RETURNS (Dict[str, Any]): The key/value pairs.
"""
return {".".join(key): value for key, value in walk_dict(obj)}
return {
".".join(key): value
for key, value in walk_dict(obj, for_overrides=for_overrides)
}


def dot_to_object(config: Config, section: str):
Expand Down Expand Up @@ -1528,13 +1533,20 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None:


def walk_dict(
node: Dict[str, Any], parent: List[str] = []
node: Dict[str, Any], parent: List[str] = [], *, for_overrides: bool = False
) -> Iterator[Tuple[List[str], Any]]:
"""Walk a dict and yield the path and values of the leaves."""
"""Walk a dict and yield the path and values of the leaves.
for_overrides (bool): Whether to treat registered functions that start with
@ as final values rather than dicts to traverse.
"""
for key, value in node.items():
key_parent = [*parent, key]
if isinstance(value, dict):
yield from walk_dict(value, key_parent)
if isinstance(value, dict) and (
not for_overrides
or not any(value_key.startswith("@") for value_key in value)
):
yield from walk_dict(value, key_parent, for_overrides=for_overrides)
else:
yield (key_parent, value)

Expand Down

0 comments on commit b8d40ca

Please sign in to comment.