diff --git a/ir_axioms/modules/similarity.py b/ir_axioms/modules/similarity.py index 3f31f46..fdc4ada 100644 --- a/ir_axioms/modules/similarity.py +++ b/ir_axioms/modules/similarity.py @@ -12,6 +12,7 @@ from ir_axioms import logger from ir_axioms.utils.nltk import download_nltk_dependencies +import wget import os DIR_PATH = os.path.dirname(os.path.realpath(__file__)) @@ -193,6 +194,10 @@ class MagnitudeTermSimilarityMixin(TermSimilarityMixin, ABC): @cached_property def _embeddings(self): + url = 'https://files.webis.de/data-in-production/data-research/ir-axioms/wiki-news-300d-1M.magnitude' # noqa: E501 + if not os.path.isfile(self.embeddings_path) and self.embeddings_path.endswith('wiki-news-300d-1M.magnitude'): + wget.download(url, out=self.embeddings_path) + return Magnitude(self.embeddings_path) @final @@ -202,5 +207,9 @@ def similarity(self, term1: str, term2: str): class FastTextWikiNewsTermSimilarityMixin(MagnitudeTermSimilarityMixin): - # wget via: https://files.webis.de/data-in-production/data-research/ir-axioms/wiki-news-300d-1M.magnitude # noqa: E501 embeddings_path: Final[str] = f"{DIR_PATH}/wiki-news-300d-1M.magnitude" + + def __init__(self): + super().__init__() + + diff --git a/pyproject.toml b/pyproject.toml index 774c646..97acf39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "trectools~=0.0.44", "typing-extensions~=4.0", "xxhash~=3.0", + "wget", ] dynamic = ["version"]