Skip to content

Commit

Permalink
More docstring tweaking
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Apr 28, 2021
1 parent 8523396 commit ee73681
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
18 changes: 15 additions & 3 deletions modnet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ class BaseMODNetModel(ABC):
"""

can_return_uncertainty = False
can_return_uncertainty: bool = False
"""Whether or not this model supports the ``return_unc`` parameter
in its `predict` method, which will enable returning uncertainties.
"""

def __init__(
self,
Expand Down Expand Up @@ -211,8 +214,17 @@ def fit_preset(
"""

@abstractmethod
def evaluate(self, test_data):
pass
def evaluate(self, test_data: MODData) -> pd.DataFrame:
"""Evaluates the target values for the passed `MODData` and returns the corresponding loss.
Parameters:
test_data: A featurized and feature-selected `MODData`
object containing the descriptors used in training.
Returns:
An array containing the defined losses for the model on the passed test data.
"""

def save(self, filename: str):
"""Save the `MODNetModel` to filename:
Expand Down
6 changes: 3 additions & 3 deletions modnet/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ class OR only return the most probable class.
return df_mean

def evaluate(self, test_data: MODData) -> pd.DataFrame:
"""Evaluates the target values for the passed MODData by returning the corresponding loss.
"""Evaluates the target values for the passed `MODData` and returns the corresponding loss.
Parameters:
test_data: A featurized and feature-selected `MODData`
object containing the descriptors used in training.
Returns:
Loss score
An array containing the defined losses for the model on the passed test data.
"""
all_losses = np.zeros(self.n_models)
for i, m in enumerate(self.model):
Expand Down
8 changes: 5 additions & 3 deletions modnet/models/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class MODNetModel(BaseMODNetModel):
"""

can_return_uncertainty = False

def build_model(
self,
targets: List,
Expand Down Expand Up @@ -536,15 +538,15 @@ class OR only return the most probable class.
return predictions

def evaluate(self, test_data: MODData) -> pd.DataFrame:
"""Evaluates the target values for the passed MODData by returning the corresponding loss.
"""Evaluates the target values for the passed `MODData` and returns the corresponding loss.
Parameters:
test_data: A featurized and feature-selected `MODData`
object containing the descriptors used in training.
Returns:
Loss score
An array containing the defined losses for the model on the passed test data.
"""
# prevents Nan predictions if some features are inf
x = (
Expand Down

0 comments on commit ee73681

Please sign in to comment.