Skip to content

Commit

Permalink
Make check_is_fitted work again
Browse files Browse the repository at this point in the history
`check_is_fitted` was changed (see #570) in a way that we won't raise
a `NotInitializedError` anymore despite the net not being
initialized. This is now solved by basically porting the old
`check_is_fitted` behavior to skorch.
  • Loading branch information
BenjaminBossan committed Dec 5, 2019
1 parent 47f0c5b commit e9f35a9
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import numpy as np
from scipy import sparse
from sklearn.utils import safe_indexing
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
import torch
from torch.nn.utils.rnn import PackedSequence
from torch.utils.data.dataset import Subset
Expand Down Expand Up @@ -494,15 +492,13 @@ def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
msg = ("This %(name)s instance is not initialized yet. Call "
"'initialize' or 'fit' with appropriate arguments "
"before using this method.")
try:
sklearn_check_is_fitted(
estimator=estimator,
attributes=attributes,
msg=msg,
all_or_any=all_or_any,
)
except NotFittedError as e:
raise NotInitializedError(str(e))


if not isinstance(attributes, (list, tuple)):
attributes = [attributes]

if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
raise NotInitializedError(msg % {'name': type(estimator).__name__})


class TeeGenerator:
Expand Down

0 comments on commit e9f35a9

Please sign in to comment.