diff --git a/build_use.sh b/build_use.sh index e9e0d3b..16c9aac 100644 --- a/build_use.sh +++ b/build_use.sh @@ -8,9 +8,9 @@ cp meta/meta.json use_model/meta.json # create the package mkdir -p use_package python -m spacy package use_model use_package --force -pushd use_package/en_use-0.1.2 +pushd use_package/en_use-0.1.3 # zip it python setup.py sdist # install the tar.gz from dist/en_use-0.1.1.tar.gz -pip install dist/en_use-0.1.2.tar.gz +pip install dist/en_use-0.1.3.tar.gz popd \ No newline at end of file diff --git a/meta/meta.json b/meta/meta.json index 2d10ba0..2472169 100644 --- a/meta/meta.json +++ b/meta/meta.json @@ -1,7 +1,7 @@ { "lang": "en", "name": "use", - "version": "0.1.2", + "version": "0.1.3", "spacy_version": ">=2.2.3", "description": "Using TFHub USE", "author": "Martino Mensio", diff --git a/setup.py b/setup.py index 3d05841..06d7a4d 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,10 @@ from setuptools import setup, find_packages -# to import the version and also run the module one time (download cache model) +# to import the version import universal_sentence_encoder def setup_package(): + # run the module one time (download cache model) + universal_sentence_encoder.UniversalSentenceEncoder.create_wrapper() setup( name="universal_sentence_encoder", entry_points={ diff --git a/universal_sentence_encoder/__init__.py b/universal_sentence_encoder/__init__.py index 2becc76..4b49d82 100644 --- a/universal_sentence_encoder/__init__.py +++ b/universal_sentence_encoder/__init__.py @@ -9,7 +9,7 @@ from spacy.tokens import Span from spacy.matcher import Matcher -__version__ = "0.1.0" +__version__ = "0.1.3" from .language import UniversalSentenceEncoder UniversalSentenceEncoder.install_extensions() @@ -30,7 +30,8 @@ def __init__(self, nlp, enable_cache): # enable_cache = cfg.get('enable_cache', True) # UniversalSentenceEncoder.install_extensions() print('enable_cache', enable_cache) - UniversalSentenceEncoder.tf_wrapper.enable_cache = enable_cache + # load tfhub now (not compulsory but nice to have it loaded when running `spacy.load`) + UniversalSentenceEncoder.create_wrapper(enable_cache=enable_cache) def __call__(self, doc): UniversalSentenceEncoder.overwrite_vectors(doc) diff --git a/universal_sentence_encoder/language.py b/universal_sentence_encoder/language.py index 0afd8d4..f1874b6 100644 --- a/universal_sentence_encoder/language.py +++ b/universal_sentence_encoder/language.py @@ -16,10 +16,22 @@ class UniversalSentenceEncoder(Language): @staticmethod def install_extensions(): - UniversalSentenceEncoder.tf_wrapper = TFHubWrapper() + def get_encoding(token_span_doc): + # tokens, spans and docs all have the `.doc` property + wrapper = token_span_doc.doc._.tfhub_wrapper + if wrapper == None: + raise ValueError('Wrapper None') + return wrapper.embed_one(token_span_doc) + + # Placeholder for a reference to the wrapper + Doc.set_extension('tfhub_wrapper', default=None, force=True) # set the extension both on doc and span level - Span.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True) - Doc.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True) + # Token.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True) + # Span.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True) + # Doc.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True) + Token.set_extension('universal_sentence_encoding', getter=get_encoding, force=True) + Span.set_extension('universal_sentence_encoding', getter=get_encoding, force=True) + Doc.set_extension('universal_sentence_encoding', getter=get_encoding, force=True) @staticmethod def overwrite_vectors(doc): @@ -33,6 +45,9 @@ def overwrite_vectors(doc): # doc.user_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding # doc.user_span_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding # doc.user_token_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding + + # save a reference to the wrapper + doc._.tfhub_wrapper = TFHubWrapper.get_instance() return doc @@ -44,6 +59,17 @@ def create_nlp(language_base='en'): nlp.add_pipe(UniversalSentenceEncoder.overwrite_vectors) return nlp + # def __init__(self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, **kwargs): + # self.tf_wrapper = TFHubWrapper.get_instance() + # super.__init__(self, vocab, make_doc, max_length, meta=meta, **kwargs) + + @staticmethod + def create_wrapper(enable_cache=True): + """Helper method, run to do the loading now""" + UniversalSentenceEncoder.tf_wrapper = TFHubWrapper.get_instance() + # TODO the enable_cache with singleton is not a great idea + UniversalSentenceEncoder.tf_wrapper.enable_cache = enable_cache + class UniversalSentenceEncoderPipe(Pipe): pass @@ -51,6 +77,15 @@ class UniversalSentenceEncoderPipe(Pipe): class TFHubWrapper(object): embed_cache: Dict[str, Any] enable_cache = True + instance = None + + @staticmethod + def get_instance(): + # singleton + if not TFHubWrapper.instance: + TFHubWrapper.instance = TFHubWrapper() + return TFHubWrapper.instance + def __init__(self): self.embed_cache = {}