Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to pickle protocol 4 when saving models #3065

Merged
merged 4 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gensim/similarities/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, model=None, num_trees=None):
raise ValueError("Only a Word2Vec, Doc2Vec, FastText or KeyedVectors instance can be used")
self._build_from_model(kv.get_normed_vectors(), kv.index_to_key, kv.vector_size)

def save(self, fname, protocol=2):
def save(self, fname, protocol=utils.PICKLE_PROTOCOL):
"""Save AnnoyIndexer instance to disk.

Parameters
Expand Down
3 changes: 2 additions & 1 deletion gensim/similarities/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
except ImportError:
raise ImportError("NMSLIB not installed. To use the NMSLIB indexer, please run `pip install nmslib`.")

from gensim import utils
from gensim.models.doc2vec import Doc2Vec
from gensim.models.word2vec import Word2Vec
from gensim.models.fasttext import FastText
Expand Down Expand Up @@ -141,7 +142,7 @@ def __init__(self, model, index_params=None, query_time_params=None):
else:
raise ValueError("model must be a Word2Vec, Doc2Vec, FastText or KeyedVectors instance")

def save(self, fname, protocol=2):
def save(self, fname, protocol=utils.PICKLE_PROTOCOL):
"""Save this NmslibIndexer instance to a file.

Parameters
Expand Down
25 changes: 19 additions & 6 deletions gensim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@

logger = logging.getLogger(__name__)

# When pickling objects for persistence, use this protocol by default.
# Note that users won't be able to load models saved with high protocols on older environments that do
# not support that protocol (e.g. Python 2).
# In the rare cases where this matters, users can explicitly pass `model.save(pickle_protocol=2)`.
# See also https://github.com/RaRe-Technologies/gensim/pull/3065
PICKLE_PROTOCOL = 4

PAT_ALPHABETIC = re.compile(r'(((?![\d])\w)+)', re.UNICODE)
RE_HTML_ENTITY = re.compile(r'&(#?)([xX]?)(\w{1,8});', re.UNICODE)
Expand Down Expand Up @@ -567,7 +573,10 @@ def _adapt_by_suffix(fname):
compress, suffix = (True, 'npz') if fname.endswith('.gz') or fname.endswith('.bz2') else (False, 'npy')
return compress, lambda *args: '.'.join(args + (suffix,))

def _smart_save(self, fname, separately=None, sep_limit=10 * 1024**2, ignore=frozenset(), pickle_protocol=2):
def _smart_save(
self, fname,
separately=None, sep_limit=10 * 1024**2, ignore=frozenset(), pickle_protocol=PICKLE_PROTOCOL,
):
"""Save the object to a file. Used internally by :meth:`gensim.utils.SaveLoad.save()`.

Parameters
Expand Down Expand Up @@ -595,8 +604,9 @@ def _smart_save(self, fname, separately=None, sep_limit=10 * 1024**2, ignore=fro
"""
compress, subname = SaveLoad._adapt_by_suffix(fname)

restores = self._save_specials(fname, separately, sep_limit, ignore, pickle_protocol,
compress, subname)
restores = self._save_specials(
fname, separately, sep_limit, ignore, pickle_protocol, compress, subname,
)
try:
pickle(self, fname, protocol=pickle_protocol)
finally:
Expand Down Expand Up @@ -711,7 +721,10 @@ def _save_specials(self, fname, separately, sep_limit, ignore, pickle_protocol,
raise
return restores + [(self, asides)]

def save(self, fname_or_handle, separately=None, sep_limit=10 * 1024**2, ignore=frozenset(), pickle_protocol=2):
def save(
self, fname_or_handle,
separately=None, sep_limit=10 * 1024**2, ignore=frozenset(), pickle_protocol=PICKLE_PROTOCOL,
):
"""Save the object to a file.

Parameters
Expand Down Expand Up @@ -1410,7 +1423,7 @@ def smart_extension(fname, ext):
return fname


def pickle(obj, fname, protocol=2):
def pickle(obj, fname, protocol=PICKLE_PROTOCOL):
"""Pickle object `obj` to file `fname`, using smart_open so that `fname` can be on S3, HDFS, compressed etc.

Parameters
Expand All @@ -1420,7 +1433,7 @@ def pickle(obj, fname, protocol=2):
fname : str
Path to pickle file.
protocol : int, optional
Pickle protocol number. Default is 2 in order to support compatibility across python 2.x and 3.x.
Pickle protocol number.

"""
with open(fname, 'wb') as fout: # 'b' for binary, needed on Windows
Expand Down