diff --git a/constituent_treelib/core.py b/constituent_treelib/core.py
index 6c18af8..2e9af1d 100644
--- a/constituent_treelib/core.py
+++ b/constituent_treelib/core.py
@@ -4,17 +4,15 @@
 import spacy
 import benepar
 import huspacy
-import fasttext
 import contractions
 from nltk import Tree
 from pathlib import Path
 import importlib.resources
 from enum import Enum, auto
-from fasttext import FastText
+from langid.langid import LanguageIdentifier, model
 from typing import List, Dict, Set, Union, Generator
 
 # Package imports
-from . import lang_models
 from .errors import *
 from .export import export_figure
 
@@ -271,6 +269,9 @@ def __init__(self, sentence: Union[str, BracketedTree, Tree], nlp: spacy.Languag
             raise SentenceError("The given sentence is either none or empty. Please provide a valid sentence in order "
                                 "to instantiate a ConstituentTree object.")
 
+        # Load the language detector model.
+        self.lang_det = LanguageIdentifier.from_modelstring(model, norm_probs=True)
+
         # Detect the language of the given sentence in order to load the correct spaCy and benepar models.
         detected_language = Language.Unsupported
 
@@ -385,9 +386,9 @@ def __init__(self, sentence: Union[str, BracketedTree, Tree], nlp: spacy.Languag
                             f"(a string, an nltk.Tree or a BracketedTree) must be provided. Type of the "
                             f"given sentence: {type(sentence).__name__}.")
 
-    def detect_language(self, text: str, append_proba: bool = False, round_precision: int = 3, top_k_matches: int = 1,
-                        model: FastText._FastText = None):
-        """Detects the language of the given text using FastText.
+
+    def detect_language(self, text: str, append_proba: bool = False, round_precision: int = 3, top_k_matches: int = 1):
+        """Detects the language of the given text using the Python library langid.
 
         Args:
             text: The text whose language is to be detected.
@@ -399,32 +400,25 @@
             top_k_matches: Number of k most likely detected languages. By default (k=1), the language
                            with the highest detection probability is returned.
 
-            model: The desired language detection model.
-
         Returns:
-            The language of the given text (optionally, the top-k detected languages and the detection probability).
-        """
-        # Suppress FastText's annoying deprecation warning.
-        FastText.eprint = lambda x: None
-
-        # Load the default compressed model if no other model is provided. The none-compressed model can be
-        # downloaded from: https://fasttext.cc/docs/en/language-identification.html
-        # A performance comparison of the two models can be found at https://fasttext.cc/blog/2017/10/02/blog-post.html
-        if not model:
-            with importlib.resources.path(lang_models, "lid.176.ftz") as resource:
-                model = fasttext.load_model(str(resource))
-
+            The language of the given text (optionally, the top-k detected languages and the detection probability).
+        """
+
+        predictions = self.lang_det.rank(text)
+
+        if top_k_matches > len(predictions):
+            raise ValueError(f"The given 'top_k_matches' exceeds the number of langid's known languages. "
" + "Consider: top_k_matches < {len(predictions)}.") + + predictions = predictions[0:top_k_matches] result = [] - predictions = model.predict([text], k=top_k_matches) - predictions = list(zip(*predictions[0], *predictions[1])) for lang, proba in predictions: - lang = lang.replace('__label__', '') lang = self.lang_dict[lang] if lang in self.lang_dict else Language.Unsupported proba = round(proba, round_precision) result.append((lang, proba)) if append_proba else result.append(lang) - return result[0] if top_k_matches == 1 else result - + return result[0] if top_k_matches == 1 else result + def detect_spacy_langauge(self, nlp: spacy.Language = None): """ Translates the language identifier of the internal spaCy pipeline into a corresponding ConstituentTreelib.Langauge object. diff --git a/constituent_treelib/lang_models/__init__.py b/constituent_treelib/lang_models/__init__.py deleted file mode 100644 index 23c7f8a..0000000 --- a/constituent_treelib/lang_models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- coding: UTF-8 -*- -import warnings -from constituent_treelib.core import * -warnings.filterwarnings("ignore", message=".*torch_struct.distributions.TreeCRF.*") diff --git a/constituent_treelib/lang_models/lid.176.ftz b/constituent_treelib/lang_models/lid.176.ftz deleted file mode 100644 index 1fb85b3..0000000 Binary files a/constituent_treelib/lang_models/lid.176.ftz and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index ff6ad6d..9880773 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "constituent_treelib" -version = "0.0.6" +version = "0.0.7" authors = [ { name="Oren Halvani" } ] @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "benepar", "contractions", - "fasttext", + "langid", "nltk", "pytest", "streamlit", diff --git a/requirements.txt b/requirements.txt index d813da6..c9af970 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ benepar contractions -fasttext +langid huspacy nltk pdfkit diff --git a/tests/test_ctl_core.py b/tests/test_ctl_core.py index e37e3f1..c6bcb86 100644 --- a/tests/test_ctl_core.py +++ b/tests/test_ctl_core.py @@ -79,14 +79,6 @@ def test_error_nlp_pipeline_without_benepar_component(self): sentence = "I will not instantiate a ConstituentTree object with an invalid nlp pipeline ever again." ConstituentTree(sentence, nlp=self.defect_nlp) - def test_error_nlp_pipeline_models_language_mismatch(self): - benepar_model = "benepar_fr2" - self.defect_nlp.add_pipe("benepar", config={"model": benepar_model}) - - with pytest.raises(LanguageError): - sentence = "I will not instantiate a ConstituentTree object with an invalid nlp pipeline ever again." - ConstituentTree(sentence, nlp=self.defect_nlp) - def test_error_nlp_pipeline_models_sentence_language_mismatch(self): with pytest.raises(LanguageError): sentence = "Huch, das war jetzt nicht gewollt." @@ -170,5 +162,5 @@ def test_svg_export(self): ConstituentTree(sentence, self.nlp).export_tree(file_name) with open(file_name, "rb") as file: md5_hash = hashlib.md5(file.read()).hexdigest() - assert md5_hash == "75cdbbeda69e84df53b6d07428e54244" + assert md5_hash == "d3e9fdbe78fee450f212d605584f3b2a" os.remove(file_name)