Merge pull request #399 from roedoejet/dev.dhd/activate_typing
Enable type-checking and correct some resulting problems
joanise authored Sep 12, 2024
2 parents 5631210 + bbcd1e8 commit b315a6c
Showing 7 changed files with 113 additions and 70 deletions.
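This PR turns on static type-checking for the codebase and fixes the errors it surfaced. The CI wiring is not shown in this diff; a minimal way to reproduce the check locally, assuming mypy is the checker (the `# type: ignore` comments and `cast()` calls below are mypy idioms), would be:

# Hypothetical local reproduction of the check, not part of this diff:
# run mypy over the g2p package via its public Python API.
from mypy import api

report, errors, exit_status = api.run(["g2p"])
print(report)
assert exit_status == 0, "type errors found"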
8 changes: 4 additions & 4 deletions g2p/app.py
@@ -237,7 +237,7 @@ async def convert(sid, message):


 @SIO.on("table event", namespace="/table")  # type: ignore
-async def change_table(sid, message):
+async def change_table(sid, message) -> None:
     """Change the lookup table"""
     LOGGER.debug("/table: %s", message)
     if "in_lang" not in message or "out_lang" not in message:
@@ -250,7 +250,7 @@ async def change_table(sid, message):
     elif message["in_lang"] == "custom" or message["out_lang"] == "custom":
         # These are only used to generate JSON to send to the client,
         # so it's safe to create a list of references to the same thing.
-        mappings = [
+        mapping_dicts = [
             {"in": "", "out": "", "context_before": "", "context_after": ""}
         ] * DEFAULT_N
         abbs = [[""] * 6] * DEFAULT_N
@@ -272,7 +272,7 @@ async def change_table(sid, message):
             "table response",
             [
                 {
-                    "mappings": mappings,
+                    "mappings": mapping_dicts,
                     "abbs": abbs,
                     "kwargs": kwargs,
                 }
@@ -292,7 +292,7 @@ async def change_table(sid, message):
                 {
                     "mappings": x.plain_mapping(),
                     "abbs": expand_abbreviations_format(x.abbreviations),
-                    "kwargs": x.model_dump(exclude=["alignments"]),
+                    "kwargs": x.model_dump(exclude={"alignments"}),
                 }
                 for x in mappings
             ],
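The last change above is a typing fix for pydantic v2: `model_dump(exclude=...)` is declared to take a set (or mapping) of field names, not a list, so `["alignments"]` fails type-checking even though it may work at runtime. A minimal standalone sketch (the model here is illustrative, not g2p's actual Mapping class):

from typing import List, Optional
from pydantic import BaseModel

class Mapping(BaseModel):  # illustrative stand-in
    in_lang: str
    out_lang: str
    alignments: Optional[List[str]] = None

m = Mapping(in_lang="fra", out_lang="fra-ipa", alignments=[])
# A set matches the declared type of `exclude`, which is why the diff
# changes ["alignments"] to {"alignments"}:
print(m.model_dump(exclude={"alignments"}))  # {'in_lang': 'fra', 'out_lang': 'fra-ipa'}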
6 changes: 3 additions & 3 deletions g2p/cli.py
@@ -167,7 +167,7 @@ def generate_mapping(  # noqa: C901
     from_langs,
     to_langs,
     distance,
-):
+) -> None:
     """Generate a new mapping from existing mappings in the g2p system.

     This command has different modes of operation.
@@ -354,7 +354,7 @@ def generate_mapping(  # noqa: C901
         except MappingMissing as e:
             raise click.BadParameter(
                 f'Cannot find IPA mapping from "{in_lang}" to "{out_lang}": {e}',
-                param_hint=["IN_LANG", "OUT_LANG"],
+                param_hint=("IN_LANG", "OUT_LANG"),  # type: ignore
             )
         source_mappings.append(source_mapping)

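The `param_hint` change works around click's annotations: depending on the click/stubs version, `param_hint` may not be declared to accept every sequence type, so the code passes a tuple and silences the remaining complaint with `# type: ignore`. A standalone sketch of the same pattern (the command and messages here are illustrative, not g2p's CLI):

import click

@click.command()
@click.argument("in_lang")
@click.argument("out_lang")
def convert(in_lang: str, out_lang: str) -> None:
    if in_lang == out_lang:
        # param_hint names the parameter(s) blamed in the error message.
        raise click.BadParameter(
            "IN_LANG and OUT_LANG must differ",
            param_hint=("IN_LANG", "OUT_LANG"),
        )
    click.echo(f"{in_lang} -> {out_lang}")

if __name__ == "__main__":
    convert()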
@@ -769,7 +769,7 @@ def update_schema(out_dir):
     context_settings=CONTEXT_SETTINGS,
     short_help="Scan a document for unknown characters.",
 )
-def scan(lang, path):
+def scan(lang, path) -> None:
     """Scan a document for non target language characters.

     Displays the set of un-mapped characters in a document.
69 changes: 36 additions & 33 deletions g2p/mappings/__init__.py
@@ -9,8 +9,9 @@
 import os
 import re
 from copy import deepcopy
+from functools import partial
 from pathlib import Path
-from typing import Dict, List, Pattern, Union
+from typing import Callable, Dict, List, Pattern, Union

 import yaml
 from pydantic import BaseModel
@@ -56,8 +57,11 @@ def model_post_init(self, *_args, **_kwargs) -> None:
             )
         # load rules from path
         if self.rules_path is not None and not self.rules:
-            self.rules = load_from_file(self.rules_path)
-            # This is required so that we don't keep escaping special characters for example
+            # make sure self.rules is always a List[Rule] like we say it is!
+            self.rules = [Rule(**obj) for obj in load_from_file(self.rules_path)]
+            # Process the rules, keeping only non-empty ones, and
+            # expanding abbreviations. This is also required so that
+            # we don't keep escaping special characters for example
             self.rules = self.process_model_specs()
         elif self.type == MAPPING_TYPE.lexicon:
             # load alignments from path
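This change validates the raw dicts returned by `load_from_file` into `Rule` models immediately, so `self.rules` really is the `List[Rule]` its annotation claims. A minimal sketch of the idea, with a stand-in `Rule` and a hypothetical loader:

from typing import Dict, List
from pydantic import BaseModel, Field

class Rule(BaseModel):
    """Illustrative stand-in: g2p's Rule aliases rule_input/rule_output
    to the "in"/"out" keys used in mapping files."""
    rule_input: str = Field("", alias="in")
    rule_output: str = Field("", alias="out")

def load_from_file(path: str) -> List[Dict[str, str]]:
    """Hypothetical loader returning raw dicts, as load_from_file does."""
    return [{"in": "a", "out": "b"}]

# Validating each dict into a Rule up front keeps the attribute a true
# List[Rule], so mypy can check every later attribute access:
rules: List[Rule] = [Rule(**obj) for obj in load_from_file("rules.json")]
print(rules[0].rule_input)  # -> "a"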
@@ -146,7 +150,7 @@ def plain_mapping(self):
         """
         return [rule.export_to_dict() for rule in self.rules]

-    def process_model_specs(self):  # noqa: C901
+    def process_model_specs(self) -> List[Rule]:
         """Process all model specifications"""
         if self.as_is is not None:
             appropriate_setting = (
@@ -176,34 +180,36 @@ def process_model_specs(self):  # noqa: C901
             self.rules = sorted(
                 # Temporarily normalize to NFD for heuristic sorting of NFC-defined rules
                 self.rules,
-                key=lambda x: (
-                    len(normalize(strip_index_notation(x.rule_input), "NFD"))
-                    if isinstance(x, Rule)
-                    else len(normalize(x["in"], "NFD"))
-                ),
+                key=lambda x: len(normalize(strip_index_notation(x.rule_input), "NFD")),
                 reverse=True,
             )

+        def apply_to_attributes(rule: Rule, func: Callable, *attrs):
+            for k in attrs:
+                value = getattr(rule, k)
+                if value:  # won't be None since default is ""
+                    setattr(rule, k, func(value))
+
         non_empty_mappings: List[Rule] = []
         for i, rule in enumerate(self.rules):
-            if isinstance(rule, dict):
-                rule = Rule(**rule)
+            # We explicitly exclude match_pattern and
+            # intermediate_form when saving rules. Seeing either of
+            # them is a programmer error.
+            assert (
+                rule.match_pattern is None
+            ), "Either match_pattern was specified explicitly or process_model_specs was called more than once"
+            assert (
+                rule.intermediate_form is None
+            ), "Either intermediate_form was specified explicitly or process_model_specs was called more than once"
             # Expand Abbreviations
-            if (
-                self.abbreviations
-                and self.rules
-                and "match_pattern" not in self.rules[0]
-            ):
-                for key in [
-                    "rule_input",
-                    "context_before",
-                    "context_after",
-                ]:
-                    setattr(
-                        rule,
-                        key,
-                        expand_abbreviations(getattr(rule, key), self.abbreviations),
-                    )
+            if self.abbreviations:
+                apply_to_attributes(
+                    rule,
+                    partial(expand_abbreviations, abbs=self.abbreviations),
+                    "rule_input",
+                    "context_before",
+                    "context_after",
+                )
             # Reverse Rule
             if self.reverse:
                 rule.rule_input, rule.rule_output = rule.rule_output, rule.rule_input
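The new `apply_to_attributes` helper plus `functools.partial` replaces two near-duplicate `setattr` loops (abbreviation expansion here, normalization in the next hunk). A self-contained sketch of the pattern, using a simplified `Rule` and a hypothetical `expand_abbreviations`:

from functools import partial
from typing import Callable, Dict

class Rule:
    """Simplified stand-in for g2p's pydantic Rule model."""
    def __init__(self) -> None:
        self.rule_input = "VOWEL"
        self.context_before = ""
        self.context_after = "VOWEL"

def apply_to_attributes(rule: Rule, func: Callable[[str], str], *attrs: str) -> None:
    # Apply func to each named attribute, skipping empty values.
    for k in attrs:
        value = getattr(rule, k)
        if value:
            setattr(rule, k, func(value))

def expand_abbreviations(value: str, abbs: Dict[str, str]) -> str:
    # Hypothetical expander: substitute each abbreviation with its expansion.
    for abb, expansion in abbs.items():
        value = value.replace(abb, expansion)
    return value

rule = Rule()
# partial() pins the abbreviation table so the helper only needs a 1-arg callable:
apply_to_attributes(
    rule,
    partial(expand_abbreviations, abbs={"VOWEL": "[aeiou]"}),
    "rule_input",
    "context_before",
    "context_after",
)
print(rule.rule_input)  # -> [aeiou]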
@@ -214,19 +220,14 @@ def process_model_specs(self):  # noqa: C901
                 rule = escape_special_characters(rule)
             # Unicode Normalization
             if self.norm_form != NORM_FORM_ENUM.none:
-                for k in [
-                    "rule_input",
-                    "rule_output",
-                    "context_before",
-                    "context_after",
-                ]:
-                    value = getattr(rule, k)
-                    if value:
-                        setattr(
-                            rule,
-                            k,
-                            normalize(value, self.norm_form.value),
-                        )
+                apply_to_attributes(
+                    rule,
+                    partial(normalize, norm_form=self.norm_form.value),
+                    "rule_input",
+                    "rule_output",
+                    "context_before",
+                    "context_after",
+                )
             # Prevent Feeding
             if self.prevent_feeding or rule.prevent_feeding:
                 rule.intermediate_form = self._string_to_pua(rule.rule_output, i)
@@ -439,7 +440,9 @@ def export_to_dict(self):
         return {"mappings": [mapping.export_to_dict() for mapping in self.mappings]}

     @staticmethod
-    def load_mapping_config_from_path(path_to_mapping_config: Union[str, Path]):
+    def load_mapping_config_from_path(
+        path_to_mapping_config: Union[str, Path]
+    ) -> "MappingConfig":
         """Loads a mapping configuration from a path, if you just want one specific mapping
         from the config, you can try Mapping.load_mapping_from_path instead.
         """
Binary file modified g2p/mappings/langs/langs.json.gz
57 changes: 34 additions & 23 deletions g2p/mappings/utils.py
@@ -13,7 +13,7 @@
 from copy import deepcopy
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
+from typing import Any, Dict, List, Optional, Pattern, Tuple, TypeVar, Union, cast

 import regex as re
 import yaml
@@ -27,6 +27,7 @@
     field_validator,
     model_validator,
 )
+from typing_extensions import Literal

 from g2p import exceptions
 from g2p.log import LOGGER
@@ -53,12 +54,19 @@ class Rule(BaseModel):
     prevent_feeding: bool = False
     """Whether to prevent the rule from feeding other rules"""

-    match_pattern: Optional[Pattern] = None
-    """An automatically generated match_pattern based on the rule_input, context_before and context_after"""
-
-    intermediate_form: Optional[str] = None
-    """An optional intermediate form. Should be automatically generated only when prevent_feeding is True"""
+    match_pattern: Optional[Pattern] = Field(
+        None,
+        exclude=True,
+        # Don't include this in the docs because it's generated, and would require a schema update
+        # description="""An automatically generated match_pattern based on the rule_input, context_before and context_after""",
+    )
+
+    intermediate_form: Optional[str] = Field(
+        None,
+        exclude=True,
+        # Don't include this in the docs because it's generated, and would require a schema update
+        # description="""An intermediate form, automatically generated only when prevent_feeding is True""",
+    )

     comment: Optional[str] = None
     """An optional comment about the rule."""
@@ -68,8 +76,6 @@ def export_to_dict(
         self, exclude=None, exclude_none=True, exclude_defaults=True, by_alias=True
     ):
         """All the options for exporting are tedious to keep track of so this is a helper function"""
-        if exclude is None:
-            exclude = {"match_pattern": True, "intermediate_form": True}
         return self.model_dump(
             exclude=exclude,
             exclude_none=exclude_none,
@@ -126,26 +132,27 @@ def expand_abbreviations_format(data):
     return lines


-def normalize(inp: str, norm_form: str):
+def normalize(inp: str, norm_form: Union[str, None]):
     """Normalize to NFC(omposed) or NFD(ecomposed).

     Also, find any Unicode Escapes & decode 'em!
     """
-    if norm_form not in ["none", "NFC", "NFD", "NFKC", "NFKD"]:
-        raise exceptions.InvalidNormalization(normalize)
-    elif norm_form is None or norm_form == "none":
+    if norm_form is None or norm_form == "none":
         return unicode_escape(inp)
-    else:
-        normalized = ud.normalize(norm_form, unicode_escape(inp))
-        if normalized != inp:
-            LOGGER.debug(
-                "The string %s was normalized to %s using the %s standard and by decoding any Unicode escapes. "
-                "Note that this is not necessarily the final stage of normalization.",
-                inp,
-                normalized,
-                norm_form,
-            )
-        return normalized
+    if norm_form not in ["NFC", "NFD", "NFKC", "NFKD"]:
+        raise exceptions.InvalidNormalization(normalize)
+    # Sadly mypy doesn't do narrowing to literals properly
+    norm_form = cast(Literal["NFC", "NFD", "NFKC", "NFKD"], norm_form)
+    normalized = ud.normalize(norm_form, unicode_escape(inp))
+    if normalized != inp:
+        LOGGER.debug(
+            "The string %s was normalized to %s using the %s standard and by decoding any Unicode escapes. "
+            "Note that this is not necessarily the final stage of normalization.",
+            inp,
+            normalized,
+            norm_form,
+        )
+    return normalized


 # compose_indices is generic because we would like to propagate the
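The `cast(Literal[...], ...)` calls here and in the two functions below all work around the same mypy limitation: a membership test like `norm_form in ("NFC", ...)` does not narrow a `str` to a `Literal` type. `cast` is a runtime no-op that records the fact for the checker. A minimal illustration, with a hypothetical `Literal`-annotated API standing in for the stricter signatures:

import unicodedata as ud
from typing import cast
from typing_extensions import Literal

NormForm = Literal["NFC", "NFD", "NFKC", "NFKD"]

def apply_form(form: NormForm, s: str) -> str:  # hypothetical Literal-typed API
    return ud.normalize(form, s)

def normalize(inp: str, norm_form: str) -> str:
    if norm_form not in ("NFC", "NFD", "NFKC", "NFKD"):
        raise ValueError(f"bad normalization form: {norm_form}")
    # The check above proves norm_form is one of the four literals, but mypy
    # still types it as plain str; cast() is a runtime no-op that tells the
    # checker what we already know.
    return apply_form(cast(NormForm, norm_form), inp)

print(normalize("\u0063\u030c", "NFC") == "\u010d")  # True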
@@ -177,6 +184,8 @@ def normalize_to_NFD_with_indices(
 ) -> Tuple[str, List[Tuple[int, int]]]:
     """Normalize to NFD and return the indices mapping input to output characters"""
     assert norm_form in ("NFD", "NFKD")
+    # Sadly mypy doesn't do narrowing to literals properly
+    norm_form = cast(Literal["NFD", "NFKD"], norm_form)
     result = ""
     indices = []
     for i, c in enumerate(inp):
@@ -192,6 +201,8 @@ def normalize_to_NFC_with_indices(
 ) -> Tuple[str, List[Tuple[int, int]]]:
     """Normalize to NFC and return the indices mapping input to output characters"""
     assert norm_form in ("NFC", "NFKC")
+    # Sadly mypy doesn't do narrowing to literals properly
+    norm_form = cast(Literal["NFC", "NFKC"], norm_form)
     inp_nfc = ud.normalize(norm_form, inp)
     NFD_form = norm_form[:-1] + "D"  # NFC->NFD or NFKC->NFKD
     inp_nfd, indices_to_nfd = normalize_to_NFD_with_indices(inp, NFD_form)
31 changes: 30 additions & 1 deletion g2p/tests/test_mappings.py
@@ -3,6 +3,7 @@
 import io
 import json
 import os
+import re
 import unicodedata as ud
 from contextlib import redirect_stderr
 from tempfile import NamedTemporaryFile
@@ -12,9 +13,10 @@
 from pydantic import ValidationError

 from g2p import exceptions
+from g2p.exceptions import InvalidNormalization
 from g2p.log import LOGGER
 from g2p.mappings import Mapping, Rule
-from g2p.mappings.utils import NORM_FORM_ENUM, RULE_ORDERING_ENUM
+from g2p.mappings.utils import NORM_FORM_ENUM, RULE_ORDERING_ENUM, normalize
 from g2p.tests.public import __file__ as public_data
 from g2p.transducer import Transducer
@@ -57,6 +59,14 @@ def test_normalization(self):
         self.assertEqual(self.test_mapping_no_norm.rules[1].rule_input, "\u0061\u0301")
         self.assertEqual(self.test_mapping_no_norm.rules[1].rule_output, "\u0061\u0301")

+    def test_utils_normalize(self):
+        """Explicitly test our custom normalize function."""
+        self.assertEqual(normalize(r"\u0061", None), "a")
+        self.assertEqual(normalize("\u010d", "NFD"), "\u0063\u030c")
+        self.assertEqual(normalize("\u0063\u030c", "NFC"), "\u010d")
+        with self.assertRaises(InvalidNormalization):
+            normalize("FOOBIE", "BLETCH")
+
     def test_json_map(self):
         json_map = Mapping(
             rules=self.json_map["map"],
@@ -397,6 +407,25 @@ def test_g2p_studio_csv(self):
         )
         os.unlink(tf.name)

+    def test_no_reprocess(self):
+        """Ensure that attempting to reprocess a mapping is an error."""
+        with self.assertRaises(AssertionError):
+            self.test_mapping_norm.process_model_specs()
+        with self.assertRaises(ValidationError):
+            _ = Mapping(
+                rules=[{"in": "a", "out": "b", "match_pattern": re.compile("XOR OTA")}]
+            )
+        with self.assertRaises(ValidationError):
+            _ = Mapping(
+                rules=[
+                    {
+                        "in": "a",
+                        "out": "b",
+                        "intermediate_form": re.compile("HACKEM MUCHE"),
+                    }
+                ]
+            )
+

 if __name__ == "__main__":
     main()
12 changes: 6 additions & 6 deletions g2p/tests/time_panphon.py
@@ -34,16 +34,16 @@ def getPanphonDistanceSingleton1():

 def getPanphonDistanceSingleton2():
     if not hasattr(getPanphonDistanceSingleton2, "value"):
-        setattr(getPanphonDistanceSingleton2, "value", panphon.distance.Distance())
+        getPanphonDistanceSingleton2.value = panphon.distance.Distance()
     return getPanphonDistanceSingleton2.value


 for iters in (1, 1, 10, 100, 1000, 10000):
     with CodeTimer(f"getPanphonDistanceSingleton1() {iters} times"):
-        for i in range(iters):
+        for _ in range(iters):
             dst = getPanphonDistanceSingleton1()
     with CodeTimer(f"getPanphonDistanceSingleton2() {iters} times"):
-        for i in range(iters):
+        for _ in range(iters):
             dst = getPanphonDistanceSingleton2()

 for words in (1, 10):
@@ -53,14 +53,14 @@ def getPanphonDistanceSingleton2():

     with CodeTimer(f"is_panphon() on 1 word {words} times"):
         string = "ei"
-        for i in range(words):
+        for _ in range(words):
             is_panphon(string)

 for iters in (1, 10):
     with CodeTimer(f"dst init {iters} times"):
-        for i in range(iters):
+        for _ in range(iters):
             dst = panphon.distance.Distance()

 for iters in (1, 10, 100, 1000):
     with CodeTimer(f"Transducer(Mapping(id='panphon_preprocessor')) {iters} times"):
-        panphon_preprocessor = Transducer(Mapping(id="panphon_preprocessor"))
+        panphon_preprocessor = Transducer(Mapping(id="panphon_preprocessor", rules=[]))

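As an aside on the singleton pattern this script benchmarks: both timed versions memoize `panphon.distance.Distance()` on a function attribute. A common alternative, not used or timed in this file, is `functools.lru_cache`, which gives the same construct-once behaviour:

from functools import lru_cache

import panphon.distance

@lru_cache(maxsize=None)
def get_panphon_distance() -> panphon.distance.Distance:
    # Constructed once on first call; every later call returns the cached object.
    return panphon.distance.Distance()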