Skip to content

Commit

Permalink
Tweak docstrings and formatting, move pickle functions to base class
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Apr 28, 2021
1 parent ed0f00f commit 8523396
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 13 deletions.
79 changes: 73 additions & 6 deletions modnet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Optional, Callable
from typing import List, Dict, Tuple, Optional, Callable, Any
from pathlib import Path

import tensorflow as tf
import numpy as np
import pandas as pd

from modnet import __version__
Expand All @@ -34,7 +35,6 @@ class BaseMODNetModel(ABC):
target_names: The list of targets names that the model
was trained for.
"""

can_return_uncertainty = False
Expand Down Expand Up @@ -154,8 +154,61 @@ def predict(self, test_data, return_prob=False, return_unc=False):
pass

@abstractmethod
def fit_preset(self, *args, **kwargs):
pass
def fit_preset(
self,
data: MODData,
presets: List[Dict[str, Any]] = None,
val_fraction: float = 0.15,
verbose: int = 0,
classification: bool = False,
refit: bool = True,
fast: bool = False,
nested: int = 5,
callbacks: List[Any] = None,
n_jobs: Optional[int] = None,
) -> Tuple[
List[List[Any]],
np.ndarray,
Optional[List[float]],
List[List[float]],
Dict[str, Any],
]:
"""Chooses an optimal hyper-parametered MODNet model from different presets.
This function implements the "inner loop" of a cross-validation workflow. By
modifying the `nested` argument, it can be run in full nested mode (i.e.
train n_fold * n_preset models) or just with a simple random hold-out set.
The data is first fitted on several well working MODNet presets
with a validation set (10% of the furnished data by default).
Sets the `self.model` attribute to the model with the lowest mean validation loss across
all folds.
Args:
data: MODData object contain training and validation samples.
presets: A list of dictionaries containing custom presets.
verbose: The verbosity level to pass to tf.keras
val_fraction: The fraction of the data to use for validation.
classification: Whether or not we are performing classification.
refit: Whether or not to refit the final model for each fold with
the best-performing settings.
fast: Used for debugging. If `True`, only fit the first 2 presets and
reduce the number of epochs.
nested: integer specifying whether or not to perform a full nested CV. If 0,
a simple validation split is performed based on val_fraction argument.
If an integer, use this number of inner CV folds, ignoring the `val_fraction` argument.
Note: If set to 1, the value will be overwritten to a default of 5 folds.
n_jobs: number of jobs for multiprocessing
Returns:
- A list of length num_outer_folds containing lists of MODNet models of length num_inner_folds.
- A list of validation losses achieved by the best model for each fold during validation (excluding refit).
- The learning curve of the final (refitted) model (or `None` if `refit` is `False`)
- A nested list of learning curves for each trained model of lengths (num_outer_folds, num_inner folds).
- The settings of the best-performing preset.
"""

@abstractmethod
def evaluate(self, test_data):
Expand All @@ -176,15 +229,29 @@ def save(self, filename: str):
self._restore_model()
LOG.info(f"Model successfully saved as {filename}!")

def _make_picklable(self):
"""Transforms inner Keras model to JSON (serialization) such that model becomes picklable."""
model_json = self.model.to_json()
model_weights = self.model.get_weights()
self.model = (model_json, model_weights)

def _restore_model(self):
"""Restore the inner keras model from JSON (deserialization) to a full Keras model."""

model_json, model_weights = self.model
self.model = tf.keras.models.model_from_json(model_json)
self.model.set_weights(model_weights)

@staticmethod
def load(filename: str):
def load(filename: str) -> "BaseMODNetModel":
"""Load `MODNetModel` object pickled by the `.save(...)` method.
If the filename ends in "tgz", "bz2" or "zip", the pickle
will be decompressed accordingly by `pandas.read_pickle(...)`.
Returns:
The loaded `MODNetModel` object.
"""
pickled_data = None

Expand Down Expand Up @@ -219,6 +286,6 @@ def load(filename: str):
return pickled_data

raise ValueError(
f"File {filename} did not contain compatible data to create a MODNetModel object, "
f"File {filename} did not contain compatible data to create a `BaseMODNetModel` object, "
f"instead found {pickled_data.__class__.__name__}."
)
9 changes: 2 additions & 7 deletions modnet/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,12 @@ def fit_preset(
return models, val_losses, best_learning_curve, learning_curves, best_preset

def _make_picklable(self):
"""
transforms inner keras model to jsons so that th MODNet object becomes picklable.
"""

"""Calls ``model._make_pickleable(...)`` on all underlying models in the ensemble."""
for m in self.model:
m._make_picklable()

def _restore_model(self):
"""
restore inner keras model after running make_picklable
"""
"""Calls ``model._restore_model(...)`` on all underlying models in the ensemble."""

for m in self.model:
m._restore_model()
Expand Down
1 change: 1 addition & 0 deletions modnet/models/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def _restore_model(self):
self.model = tf.keras.models.model_from_json(model_json)
self.model.set_weights(model_weights)


def validate_model(
train_data=None,
val_data=None,
Expand Down

0 comments on commit 8523396

Please sign in to comment.