Skip to content

Commit

Permalink
version 0.1.3: loading the tf_wrapper just when loading the model
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinoMensio committed Feb 19, 2020
1 parent d665ade commit 23a1258
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 9 deletions.
4 changes: 2 additions & 2 deletions build_use.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ cp meta/meta.json use_model/meta.json
# create the package
mkdir -p use_package
python -m spacy package use_model use_package --force
pushd use_package/en_use-0.1.2
pushd use_package/en_use-0.1.3
# zip it
python setup.py sdist
# install the tar.gz from dist/en_use-0.1.1.tar.gz
pip install dist/en_use-0.1.2.tar.gz
pip install dist/en_use-0.1.3.tar.gz
popd
2 changes: 1 addition & 1 deletion meta/meta.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"lang": "en",
"name": "use",
"version": "0.1.2",
"version": "0.1.3",
"spacy_version": ">=2.2.3",
"description": "Using TFHub USE",
"author": "Martino Mensio",
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from setuptools import setup, find_packages
# to import the version and also run the module one time (download cache model)
# to import the version
import universal_sentence_encoder

def setup_package():
# run the module one time (download cache model)
universal_sentence_encoder.UniversalSentenceEncoder.create_wrapper()
setup(
name="universal_sentence_encoder",
entry_points={
Expand Down
5 changes: 3 additions & 2 deletions universal_sentence_encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from spacy.tokens import Span
from spacy.matcher import Matcher

__version__ = "0.1.0"
__version__ = "0.1.3"

from .language import UniversalSentenceEncoder
UniversalSentenceEncoder.install_extensions()
Expand All @@ -30,7 +30,8 @@ def __init__(self, nlp, enable_cache):
# enable_cache = cfg.get('enable_cache', True)
# UniversalSentenceEncoder.install_extensions()
print('enable_cache', enable_cache)
UniversalSentenceEncoder.tf_wrapper.enable_cache = enable_cache
# load tfhub now (not compulsory but nice to have it loaded when running `spacy.load`)
UniversalSentenceEncoder.create_wrapper(enable_cache=enable_cache)

def __call__(self, doc):
UniversalSentenceEncoder.overwrite_vectors(doc)
Expand Down
41 changes: 38 additions & 3 deletions universal_sentence_encoder/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,22 @@ class UniversalSentenceEncoder(Language):

@staticmethod
def install_extensions():
UniversalSentenceEncoder.tf_wrapper = TFHubWrapper()
def get_encoding(token_span_doc):
    """Getter for the `universal_sentence_encoding` extension attribute.

    Works uniformly for Token, Span and Doc: all three expose the
    `.doc` property, and the wrapper reference is stored on the Doc
    (in the `tfhub_wrapper` extension).

    Raises:
        ValueError: if no wrapper has been attached to the doc yet
            (i.e. the `overwrite_vectors` pipe has not run).
    """
    # tokens, spans and docs all have the `.doc` property
    wrapper = token_span_doc.doc._.tfhub_wrapper
    # PEP 8: compare against the None singleton with `is`, not `==`
    if wrapper is None:
        raise ValueError('Wrapper None')
    return wrapper.embed_one(token_span_doc)

# Placeholder for a reference to the wrapper
Doc.set_extension('tfhub_wrapper', default=None, force=True)
# set the extension both on doc and span level
# NOTE(review): the next two registrations reference the class-level
# `tf_wrapper` directly — they appear to be the pre-change (removed) lines
# of this diff, superseded by the `get_encoding`-based registrations
# below; confirm against the committed file.
Span.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True)
Doc.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True)
# Token.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True)
# Span.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True)
# Doc.set_extension('universal_sentence_encoding', getter=UniversalSentenceEncoder.tf_wrapper.embed_one, force=True)
# register the getter on all three levels; it resolves the wrapper lazily
# through `doc._.tfhub_wrapper`, so no model is loaded at extension time
Token.set_extension('universal_sentence_encoding', getter=get_encoding, force=True)
Span.set_extension('universal_sentence_encoding', getter=get_encoding, force=True)
Doc.set_extension('universal_sentence_encoding', getter=get_encoding, force=True)

@staticmethod
def overwrite_vectors(doc):
Expand All @@ -33,6 +45,9 @@ def overwrite_vectors(doc):
# doc.user_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding
# doc.user_span_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding
# doc.user_token_hooks["vector_norm"] = lambda a: a._.universal_sentence_encoding

# save a reference to the wrapper
doc._.tfhub_wrapper = TFHubWrapper.get_instance()
return doc


Expand All @@ -44,13 +59,33 @@ def create_nlp(language_base='en'):
nlp.add_pipe(UniversalSentenceEncoder.overwrite_vectors)
return nlp

# def __init__(self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, **kwargs):
# self.tf_wrapper = TFHubWrapper.get_instance()
# super.__init__(self, vocab, make_doc, max_length, meta=meta, **kwargs)

@staticmethod
def create_wrapper(enable_cache=True):
    """Eagerly obtain the shared TFHubWrapper and set its cache flag.

    Helper to trigger the (potentially slow) model loading up front,
    e.g. while `spacy.load` is running, instead of on first use.
    """
    wrapper = TFHubWrapper.get_instance()
    # TODO the enable_cache with singleton is not a great idea
    wrapper.enable_cache = enable_cache
    UniversalSentenceEncoder.tf_wrapper = wrapper

class UniversalSentenceEncoderPipe(Pipe):
    """Empty `Pipe` subclass — appears to be a placeholder with no behaviour yet."""
    pass


class TFHubWrapper(object):
    """Wrapper object shared process-wide via `get_instance()` (body continues below)."""
    # cache of embedding results; presumably keyed by input text — TODO confirm
    embed_cache: Dict[str, Any]
    # when True, embed results are served from / stored in embed_cache
    enable_cache = True
    # class-level singleton slot, managed by get_instance()
    instance = None

@staticmethod
def get_instance():
    """Return the process-wide TFHubWrapper, creating it on first use (singleton)."""
    if TFHubWrapper.instance is None:
        TFHubWrapper.instance = TFHubWrapper()
    return TFHubWrapper.instance


def __init__(self):
self.embed_cache = {}
Expand Down

0 comments on commit 23a1258

Please sign in to comment.