Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ardunn committed Oct 11, 2019
1 parent 2b9ef9b commit b58ae35
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
8 changes: 4 additions & 4 deletions automatminer/automl/adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def best_models(self):
"""

if self._from_serialized:
if self.from_serialized:
return self._best_models
else:
self.greater_score_is_better = is_greater_better(
Expand Down Expand Up @@ -208,7 +208,7 @@ def backend(self):

@property
def best_pipeline(self):
if self._from_serialized:
if self.from_serialized:
# The TPOT backend is replaced by the best pipeline.
return self._backend
else:
Expand All @@ -235,7 +235,7 @@ def serialize(self) -> None:
# Necessary for getting best models post serialization
self._best_models = self.best_models
self._backend = self.best_pipeline
self._from_serialized = True
self.from_serialized = True

def deserialize(self) -> None:
"""
Expand All @@ -247,7 +247,7 @@ def deserialize(self) -> None:
"""
self._backend = _adaptor_tmp_backend
_tmp_backend = None
self._from_serialized = False
self.from_serialized = False


class SinglePipelineAdaptor(DFMLAdaptor, LoggableMixin):
Expand Down
15 changes: 13 additions & 2 deletions automatminer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pprint import pformat

import yaml

from automatminer import __version__
from automatminer import TPOTAdaptor, SinglePipelineAdaptor, FeatureReducer, \
AutoFeaturizer, DataCleaner
from automatminer.base import LoggableMixin, DFTransformer
Expand Down Expand Up @@ -77,6 +79,8 @@ class MatPipe(DFTransformer, LoggableMixin):
is_fit (bool): If True, the matpipe is fit. The matpipe should be
fit before being used to predict data.
version (str): The automatminer version used for serialization and
deserialization.
"""

def __init__(self, autofeaturizer=None, cleaner=None, reducer=None,
Expand Down Expand Up @@ -107,6 +111,7 @@ def __init__(self, autofeaturizer=None, cleaner=None, reducer=None,
self.is_fit = False
self.ml_type = None
self.target = None
self.version = __version__

# @staticmethod
# def from_preset(preset: str = 'express', **powerups):
Expand Down Expand Up @@ -430,7 +435,7 @@ def save(self, filename="mat.pipe"):
loggable._logger = temp_logger

@staticmethod
def load(filename, logger=True):
def load(filename, logger=True, supress_version_mismatch=False):
"""
Loads a matpipe that was saved.
Expand All @@ -439,19 +444,25 @@ def load(filename, logger=True):
using save).
logger (bool or logging.Logger): The logger to use for the loaded
matpipe.
supress_version_mismatch (bool): If False, throws an error when
there is a version mismatch between a serialized MatPipe and the
current automatminer version. If True, supresses this error.
Returns:
pipe (MatPipe): A MatPipe object.
"""
with open(filename, 'rb') as f:
pipe = pickle.load(f)

if pipe.version != __version__ and not supress_version_mismatch:
raise AutomatminerError("Version mismatch")

pipe.logger = logger

pipe.logger.info("Loaded MatPipe from file {}.".format(filename))

if hasattr(pipe.learner, "from_serialized"):
if pipe.learner._from_serialized:
if pipe.learner.from_serialized:
pipe.logger.warning(
"Only use this model to make predictions (do not "
"retrain!). Backend was serialzed as only the top model, "
Expand Down

0 comments on commit b58ae35

Please sign in to comment.