use _save_specials/_load_specials per type
gojomo committed Jul 23, 2020
1 parent ac9126d commit 0316084
Showing 3 changed files with 118 additions and 82 deletions.
gensim/models/doc2vec.py (2 changes: 1 addition & 1 deletion)
@@ -790,7 +790,7 @@ def load(cls, *args, **kwargs):
         except AttributeError as ae:
             logger.error(
                 "Model load error. Was model saved using code from an older Gensim Version? "
-                "Try loading older model using gensim-3.8.1, then re-saving, to restore "
+                "Try loading older model using gensim-3.8.3, then re-saving, to restore "
                 "compatibility with current code.")
             raise ae

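The error-message bump from 3.8.1 to 3.8.3 points at the last 3.x release as the migration stepping stone. A hedged sketch of that round-trip, assuming a separate environment with gensim 3.8.3 installed and a hypothetical model filename:

# run this under gensim==3.8.3 (pip install gensim==3.8.3), not current code
from gensim.models import Doc2Vec

old = Doc2Vec.load('pre-4.0.model')    # hypothetical gensim-3.x model file
old.save('resaved-with-3.8.3.model')   # this re-saved copy should load under 4.x code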
gensim/models/fasttext.py (88 changes: 57 additions & 31 deletions)
@@ -279,6 +279,7 @@
 import os
 
 import numpy as np
+import itertools as it
 from numpy import ones, vstack, float32 as REAL
 import six
 from collections.abc import Iterable
@@ -822,7 +823,6 @@ def save(self, *args, **kwargs):
             Load :class:`~gensim.models.fasttext.FastText` model.

         """
-        kwargs['ignore'] = kwargs.get('ignore', []) + ['buckets_word', ]
         super(FastText, self).save(*args, **kwargs)

     @classmethod
@@ -845,25 +845,15 @@ def load(cls, *args, **kwargs):
             Save :class:`~gensim.models.fasttext.FastText` model.

         """
-        model = super(FastText, cls).load(*args, rethrow=True, **kwargs)
-
-        if not hasattr(model.wv, 'vectors_vocab_lockf') and hasattr(model.wv, 'vectors_vocab'):
-            # TODO: try trainables-location
-            model.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
-        if not hasattr(model, 'vectors_ngrams_lockf') and hasattr(model.wv, 'vectors_ngrams'):
-            # TODO: try trainables-location
-            model.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)
-        # fixup mistakenly overdimensioned gensim-3.x lockf arrays
-        if len(model.wv.vectors_vocab_lockf.shape) > 1:
-            model.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
-        if len(model.wv.vectors_ngrams_lockf.shape) > 1:
-            model.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)
-        if hasattr(model, 'bucket'):
-            del model.bucket  # should only exist in one place: the wv subcomponent
-        if not hasattr(model.wv, 'buckets_word') or not model.wv.buckets_word:
-            model.wv.recalc_char_ngram_buckets()
+        return super(FastText, cls).load(*args, rethrow=True, **kwargs)

-        return model
+    def _load_specials(self, *args, **kwargs):
+        """Handle special requirements of `.load()` protocol, usually up-converting older versions."""
+        super(FastText, self)._load_specials(*args, **kwargs)
+        if hasattr(self, 'bucket'):
+            # should only exist in one place: the wv subcomponent
+            self.wv.bucket = self.bucket
+            del self.bucket
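This hunk shows the commit's core pattern: `FastText.load()` no longer patches up the returned object, and the type-specific repair moves into `_load_specials()`, which the load protocol invokes for each (possibly nested) object. A toy sketch of how such a protocol can dispatch per-type hooks; this is an illustration under simplified assumptions (plain pickle, no large-array handling), not gensim's actual `utils.SaveLoad`:

import pickle

class SaveLoadBase:
    """Toy stand-in for gensim.utils.SaveLoad (the real one also handles
    large-array storage, mmap, and recursion into nested SaveLoad objects)."""

    def _save_specials(self, ignore):
        # return picklable state, minus attributes the subclass can rebuild
        return {k: v for k, v in self.__dict__.items() if k not in ignore}

    def _load_specials(self):
        # subclasses up-convert old payloads / rebuild dropped attributes here
        pass

    def save(self, fname, ignore=()):
        with open(fname, 'wb') as f:
            pickle.dump((type(self), self._save_specials(set(ignore))), f)

    @classmethod
    def load(cls, fname):
        with open(fname, 'rb') as f:
            klass, state = pickle.load(f)
        obj = klass.__new__(klass)
        obj.__dict__.update(state)
        obj._load_specials()  # each type repairs itself, not the caller
        return obj

class ToyModel(SaveLoadBase):
    def __init__(self):
        self.weights = [1.0, 2.0, 3.0]
        self.cum_table = self._recalc()  # derivable, so not worth storing

    def _recalc(self):
        cum, total = [], 0.0
        for w in self.weights:
            total += w
            cum.append(total)
        return cum

    def _save_specials(self, ignore):
        return super()._save_specials(ignore | {'cum_table'})

    def _load_specials(self):
        super()._load_specials()
        self.cum_table = self._recalc()  # rebuild what _save_specials dropped

ToyModel().save('toy.pkl')
assert SaveLoadBase.load('toy.pkl').cum_table == [1.0, 3.0, 6.0]

The payoff mirrors the commit: dropping and rebuilding derived state lives next to the class that owns it, so nested objects (like `model.wv`) can run their own hooks without the outer `load()` knowing the details.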
 class FastTextVocab(utils.SaveLoad):
@@ -1197,12 +1187,47 @@ def __init__(self, vector_size, min_n, max_n, bucket):

     @classmethod
     def load(cls, fname_or_handle, **kwargs):
-        model = super(FastTextKeyedVectors, cls).load(fname_or_handle, **kwargs)
-        if isinstance(model, FastTextKeyedVectors):
-            if not hasattr(model, 'compatible_hash') or model.compatible_hash is False:
-                raise TypeError("Pre-gensim-3.8.x Fasttext models with nonstandard hashing are no longer compatible."
-                                "Loading into gensim-3.8.3 & re-saving may create a compatible model.")
-        return model
+        """Load a previously saved `FastTextKeyedVectors` model.
+
+        Parameters
+        ----------
+        fname : str
+            Path to the saved file.
+
+        Returns
+        -------
+        :class:`~gensim.models.fasttext.FastTextKeyedVectors`
+            Loaded model.
+
+        See Also
+        --------
+        :meth:`~gensim.models.fasttext.FastTextKeyedVectors.save`
+            Save :class:`~gensim.models.fasttext.FastTextKeyedVectors` model.
+
+        """
+        return super(FastTextKeyedVectors, cls).load(fname_or_handle, **kwargs)

+    def _load_specials(self, *args, **kwargs):
+        """Handle special requirements of `.load()` protocol, usually up-converting older versions."""
+        super(FastTextKeyedVectors, self)._load_specials(*args, **kwargs)
+        if not isinstance(self, FastTextKeyedVectors):
+            raise TypeError("Loaded object of type %s, not expected FastTextKeyedVectors" % type(self))
+        if not hasattr(self, 'compatible_hash') or self.compatible_hash is False:
+            raise TypeError("Pre-gensim-3.8.x Fasttext models with nonstandard hashing are no longer compatible."
+                            "Loading into gensim-3.8.3 & re-saving may create a compatible model.")
+        if not hasattr(self, 'vectors_vocab_lockf') and hasattr(self, 'vectors_vocab'):
+            self.vectors_vocab_lockf = ones(1, dtype=REAL)
+        if not hasattr(self, 'vectors_ngrams_lockf') and hasattr(self, 'vectors_ngrams'):
+            self.vectors_ngrams_lockf = ones(1, dtype=REAL)
+        # fixup mistakenly overdimensioned gensim-3.x lockf arrays
+        if len(self.vectors_vocab_lockf.shape) > 1:
+            self.vectors_vocab_lockf = ones(1, dtype=REAL)
+        if len(self.vectors_ngrams_lockf.shape) > 1:
+            self.vectors_ngrams_lockf = ones(1, dtype=REAL)
+        if not hasattr(self, 'buckets_word') or not self.buckets_word:
+            self.recalc_char_ngram_buckets()
+        if not hasattr(self, 'vectors') or self.vectors is None:
+            self.adjust_vectors()  # recompose full-word vectors

     def __contains__(self, word):
         """Check if `word` or any character ngrams in `word` are present in the vocabulary.
@@ -1250,14 +1275,15 @@ def save(self, *args, **kwargs):
             Load object.

         """
-        # don't bother storing the cached normalized vectors
-        ignore_attrs = [
-            'buckets_word',
-            'hash2index',
-        ]
-        kwargs['ignore'] = kwargs.get('ignore', ignore_attrs)
         super(FastTextKeyedVectors, self).save(*args, **kwargs)

+    def _save_specials(self, fname, separately, sep_limit, ignore, pickle_protocol, compress, subname):
+        """Arrange any special handling for the gensim.utils.SaveLoad protocol"""
+        # don't save properties that are merely calculated from others
+        ignore = set(it.chain(ignore, ('buckets_word', 'vectors')))
+        return super(FastTextKeyedVectors, self)._save_specials(
+            fname, separately, sep_limit, ignore, pickle_protocol, compress, subname)
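Both new `_save_specials()` implementations merge the caller's ignore list with the type's own always-recalculable attributes using the same one-liner; a quick demonstration of its set-union semantics:

import itertools as it

caller_ignore = ['vectors_norm', 'buckets_word']   # e.g. passed down from save(..., ignore=...)
ignore = set(it.chain(caller_ignore, ('buckets_word', 'vectors')))
assert ignore == {'vectors_norm', 'buckets_word', 'vectors'}   # duplicates collapse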

     def get_vector(self, word, use_norm=False):
         """Get `word` representations in vector space, as a 1D numpy array.
gensim/models/word2vec.py (110 changes: 60 additions & 50 deletions)
@@ -128,7 +128,7 @@
 from collections import defaultdict, namedtuple
 from types import GeneratorType
 import threading
-import itertools
+import itertools as it
 import copy

 from gensim.utils import keep_vocab_item, call_on_class_only, deprecated
@@ -1788,20 +1788,14 @@ def save(self, *args, **kwargs):
             Path to the file.

-        # don't bother storing recalculable table
-        kwargs['ignore'] = kwargs.get('ignore', []) + ['cum_table', ]
         super(Word2Vec, self).save(*args, **kwargs)

-    def get_latest_training_loss(self):
-        """Get current value of the training loss.
-
-        Returns
-        -------
-        float
-            Current training loss.
-
-        """
-        return self.running_training_loss
+    def _save_specials(self, fname, separately, sep_limit, ignore, pickle_protocol, compress, subname):
+        """Arrange any special handling for the gensim.utils.SaveLoad protocol"""
+        # don't save properties that are merely calculated from others
+        ignore = set(it.chain(ignore, ('cum_table',)))
+        return super(Word2Vec, self)._save_specials(
+            fname, separately, sep_limit, ignore, pickle_protocol, compress, subname)
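`cum_table` can be dropped at save time because it is fully derivable from stored word counts, and `_load_specials()` below rebuilds it via `make_cum_table()`. A simplified sketch of the idea, a frequency table smoothed by `ns_exponent` and stored cumulatively for fast negative sampling; `toy_cum_table` is illustrative, not gensim's exact routine:

import numpy as np

def toy_cum_table(counts, ns_exponent=0.75, domain=2**31 - 1):
    # smooth raw frequencies, then keep the cumulative distribution
    # scaled into integer space for fast bisect-based sampling
    smoothed = np.array(counts, dtype=np.float64) ** ns_exponent
    cum = np.cumsum(smoothed / smoothed.sum())
    return np.round(cum * domain).astype(np.uint32)

table = toy_cum_table([100, 10, 1])
assert table[-1] == 2**31 - 1   # monotone, ending exactly at `domain`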

     @classmethod
     def load(cls, *args, rethrow=False, **kwargs):
@@ -1828,49 +1822,65 @@ def load(cls, *args, rethrow=False, **kwargs):
             if not isinstance(model, Word2Vec):
                 rethrow = True
                 raise AttributeError("Model of type %s can't be loaded by %s" % (type(model), str(cls)))
-            # for backward compatibility
-            if not hasattr(model, 'ns_exponent'):
-                model.ns_exponent = 0.75
-            if model.negative and hasattr(model.wv, 'index2word'):
-                model.make_cum_table()  # rebuild cum_table from vocabulary  ## TODO: ???
-            if not hasattr(model, 'corpus_count'):
-                model.corpus_count = None
-            if not hasattr(model, 'corpus_total_words'):
-                model.corpus_total_words = None
-            if not hasattr(model.wv, 'vectors_lockf') and hasattr(model.wv, 'vectors'):
-                model.wv.vectors_lockf = getattr(model, 'vectors_lockf', np.ones(1, dtype=REAL))
-            if not hasattr(model, 'random'):
-                model.random = np.random.RandomState(model.seed)
-            if not hasattr(model, 'train_count'):
-                model.train_count = 0
-                model.total_train_time = 0
-            if not hasattr(model, 'epochs'):
-                model.epochs = model.iter
-                del model.iter
-            if not hasattr(model, 'max_final_vocab'):
-                model.max_final_vocab = None
-            if hasattr(model, 'vocabulary'):  # re-integrate state that had been moved
-                for a in ('max_vocab_size', 'min_count', 'sample', 'sorted_vocab', 'null_word', 'raw_vocab'):
-                    setattr(model, a, getattr(model.vocabulary, a))
-                del model.vocabulary
-            if hasattr(model, 'trainables'):  # re-integrate state that had been moved
-                for a in ('hashfxn', 'layer1_size', 'seed', 'syn1neg', 'syn1'):
-                    if hasattr(model.trainables, a):
-                        setattr(model, a, getattr(model.trainables, a))
-                if hasattr(model, 'syn1'):
-                    model.syn1 = model.syn1
-                    del model.syn1
-                del model.trainables
             return model
         except AttributeError as ae:
             if rethrow:
                 raise ae
             logger.error(
                 "Model load error. Was model saved using code from an older Gensim Version? "
-                "Try loading older model using gensim-3.8.1, then re-saving, to restore "
+                "Try loading older model using gensim-3.8.3, then re-saving, to restore "
                 "compatibility with current code.")
             raise ae

+    def _load_specials(self, *args, **kwargs):
+        """Handle special requirements of `.load()` protocol, usually up-converting older versions."""
+        super(Word2Vec, self)._load_specials(*args, **kwargs)
+        # for backward compatibility, add/rearrange properties from prior versions
+        if not hasattr(self, 'ns_exponent'):
+            self.ns_exponent = 0.75
+        if self.negative and hasattr(self.wv, 'index_to_key'):
+            self.make_cum_table()  # rebuild cum_table from vocabulary
+        if not hasattr(self, 'corpus_count'):
+            self.corpus_count = None
+        if not hasattr(self, 'corpus_total_words'):
+            self.corpus_total_words = None
+        if not hasattr(self.wv, 'vectors_lockf') and hasattr(self.wv, 'vectors'):
+            self.wv.vectors_lockf = getattr(self, 'vectors_lockf', np.ones(1, dtype=REAL))
+        if not hasattr(self, 'random'):
+            # use new instance of numpy's recommended generator/algorithm
+            self.random = np.random.default_rng(seed=self.seed)
+        if not hasattr(self, 'train_count'):
+            self.train_count = 0
+            self.total_train_time = 0
+        if not hasattr(self, 'epochs'):
+            self.epochs = self.iter
+            del self.iter
+        if not hasattr(self, 'max_final_vocab'):
+            self.max_final_vocab = None
+        if hasattr(self, 'vocabulary'):  # re-integrate state that had been moved
+            for a in ('max_vocab_size', 'min_count', 'sample', 'sorted_vocab', 'null_word', 'raw_vocab'):
+                setattr(self, a, getattr(self.vocabulary, a))
+            del self.vocabulary
+        if hasattr(self, 'trainables'):  # re-integrate state that had been moved
+            for a in ('hashfxn', 'layer1_size', 'seed', 'syn1neg', 'syn1'):
+                if hasattr(self.trainables, a):
+                    setattr(self, a, getattr(self.trainables, a))
+            if hasattr(self, 'syn1'):
+                self.syn1 = self.syn1
+                del self.syn1
+            del self.trainables

+    def get_latest_training_loss(self):
+        """Get current value of the training loss.
+
+        Returns
+        -------
+        float
+            Current training loss.
+
+        """
+        return self.running_training_loss
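One up-conversion in the `_load_specials()` above is a quiet behavior change worth noting: a loaded model that lacks a `random` attribute gets numpy's recommended `default_rng` (PCG64) rather than the legacy `RandomState` (MT19937) that old pickles carried, so the same seed yields a different stream of draws. A small demonstration:

import numpy as np

legacy = np.random.RandomState(1)        # what an old pickle carried
modern = np.random.default_rng(seed=1)   # what _load_specials now installs

# same seed, different algorithms (MT19937 vs PCG64), different streams:
print(legacy.random_sample(3))
print(modern.random(3))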


 class BrownCorpus(object):
     def __init__(self, dirname):
@@ -1958,7 +1968,7 @@ def __iter__(self):
             # Assume it is a file-like object and try treating it as such
             # Things that don't have seek will trigger an exception
             self.source.seek(0)
-            for line in itertools.islice(self.source, self.limit):
+            for line in it.islice(self.source, self.limit):
                 line = utils.to_unicode(line).split()
                 i = 0
                 while i < len(line):
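The `itertools` to `it` rename is purely cosmetic; `islice` still implements the `limit` contract here, where `None` means read everything and an integer caps the number of lines. A quick demonstration:

import itertools as it

lines = ['first line', 'second line', 'third line']
assert list(it.islice(iter(lines), None)) == lines   # limit=None: read everything
assert list(it.islice(iter(lines), 2)) == lines[:2]  # limit=2: stop early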
@@ -1967,7 +1977,7 @@
         except AttributeError:
             # If it didn't work like a file, use it as a string filename
             with utils.open(self.source, 'rb') as fin:
-                for line in itertools.islice(fin, self.limit):
+                for line in it.islice(fin, self.limit):
                     line = utils.to_unicode(line).split()
                     i = 0
                     while i < len(line):
@@ -2021,7 +2031,7 @@ def __iter__(self):
         for file_name in self.input_files:
             logger.info('reading file %s', file_name)
             with utils.open(file_name, 'rb') as fin:
-                for line in itertools.islice(fin, self.limit):
+                for line in it.islice(fin, self.limit):
                     line = utils.to_unicode(line).split()
                     i = 0
                     while i < len(line):
