Commit

WIP: low hanging fix
FBruzzesi committed Dec 12, 2024
1 parent 13b20df commit a52bec1
Showing 6 changed files with 35 additions and 11 deletions.
6 changes: 4 additions & 2 deletions sklego/linear_model.py
@@ -970,8 +970,6 @@ def __init__(
         self.fit_intercept = fit_intercept
         self.copy_X = copy_X
         self.positive = positive
-        if method not in ("SLSQP", "TNC", "L-BFGS-B"):
-            raise ValueError(f'method should be one of "SLSQP", "TNC", "L-BFGS-B", ' f"got {method} instead")
         self.method = method

     @abstractmethod
@@ -1021,6 +1019,10 @@ def fit(self, X, y, sample_weight=None):
        self : BaseScipyMinimizeRegressor
            Fitted linear model.
        """
+        if self.method not in {"SLSQP", "TNC", "L-BFGS-B"}:
+            msg = f"method should be one of 'SLSQP', 'TNC', 'L-BFGS-B', got {self.method} instead"
+            raise ValueError(msg)
+
        X_, grad_loss, loss = self._prepare_inputs(X, sample_weight, y)

        d = X_.shape[1] - self.n_features_in_  # This is either zero or one.
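The net effect: an invalid `method` now surfaces when `fit` is called rather than at construction, in line with scikit-learn's convention that `__init__` only stores parameters. A minimal sketch of the new behavior, using LADRegression as an assumed concrete subclass of BaseScipyMinimizeRegressor:

import numpy as np
from sklego.linear_model import LADRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = X @ np.array([1.0, 2.0, 0.5])

model = LADRegression(method="not-a-method")  # constructing no longer raises
try:
    model.fit(X, y)  # the method check now happens here
except ValueError as exc:
    print(exc)  # method should be one of 'SLSQP', 'TNC', 'L-BFGS-B', ...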
4 changes: 4 additions & 0 deletions sklego/meta/decay_estimator.py
@@ -97,6 +97,10 @@ def _is_classifier(self):
         """Checks if the wrapped estimator is a classifier."""
         return any(["ClassifierMixin" in p.__name__ for p in type(self.model).__bases__])

+    def _is_regressor(self):
+        """Checks if the wrapped estimator is a regressor."""
+        return any(["RegressorMixin" in p.__name__ for p in type(self.model).__bases__])
+
     @property
     def _estimator_type(self):
         """Computes `_estimator_type` dynamically from the wrapped model."""
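The new `_is_regressor` mirrors `_is_classifier`: it scans the names of the wrapped model's direct base classes. A quick illustration of that check against a plain scikit-learn estimator (LinearRegression chosen here only for illustration):

from sklearn.linear_model import LinearRegression

# RegressorMixin is a direct base class of LinearRegression, so the scan succeeds.
model = LinearRegression()
print(any("RegressorMixin" in p.__name__ for p in type(model).__bases__))  # True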
6 changes: 5 additions & 1 deletion sklego/meta/grouped_transformer.py
@@ -111,6 +111,7 @@ def fit(self, X, y=None):
         self.groups_ = as_list(self.groups) if self.groups is not None else []

         X = nw.from_native(X, strict=False, eager_only=True)
+        self.n_features_in_ = X.shape[1]

         if isinstance(X, nw.DataFrame):
             self.feature_names_out_ = [c for c in X.columns if c not in self.groups_]
@@ -193,9 +194,12 @@ def transform(self, X):
        array-like of shape (n_samples, n_features)
            Data transformed per group.
        """
-        check_is_fitted(self, ["fallback_", "transformers_"])
+        check_is_fitted(self, ["n_features_in_", "transformers_"])

         X = nw.from_native(X, strict=False, eager_only=True)
+        if X.shape[1] != self.n_features_in_:
+            raise ValueError(f"X has {X.shape[1]} features, expected {self.n_features_in_} features.")
+
         frame = parse_X_y(X, y=None, groups=self.groups_, check_X=self.check_X, **self._check_kwargs).drop(
             "__sklego_target__"
         )
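With `n_features_in_` recorded during `fit`, `transform` can now reject inputs whose width differs from what was seen at fit time. A hedged sketch, assuming a pandas DataFrame input and StandardScaler as the wrapped transformer:

import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklego.meta import GroupedTransformer

df = pd.DataFrame({"g": [0, 0, 1, 1], "x": [1.0, 2.0, 3.0, 4.0]})
trf = GroupedTransformer(StandardScaler(), groups="g").fit(df)

try:
    trf.transform(df.drop(columns="x"))  # 1 column instead of the 2 seen in fit
except ValueError as exc:
    print(exc)  # X has 1 features, expected 2 features.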
17 changes: 12 additions & 5 deletions sklego/meta/zero_inflated_regressor.py
@@ -99,7 +99,8 @@ def fit(self, X, y, sample_weight=None):
            If all train target entirely consists of zeros and `handle_zero="error"`
        """
        X, y = check_X_y(X, y)
-        self._check_n_features(X, reset=True)
+        self.n_features_in_ = X.shape[1]
+
        if not is_classifier(self.classifier):
            raise ValueError(
                f"`classifier` has to be a classifier. Received instance of {type(self.classifier)} instead."
@@ -173,9 +174,12 @@ def predict(self, X):
        array-like of shape (n_samples,)
            The predicted values.
        """
-        check_is_fitted(self)
+        check_is_fitted(self, ["n_features_in_", "classifier_", "regressor_"])
        X = check_array(X)
-        self._check_n_features(X, reset=False)
+
+        if X.shape[1] != self.n_features_in_:
+            msg = f"Unexpected input dimension {X.shape[1]}, expected {self.n_features_in_}"
+            raise ValueError(msg)

        output = np.zeros(len(X))
        non_zero_indices = np.where(self.classifier_.predict(X))[0]
@@ -211,9 +215,12 @@ def score_samples(self, X):
            The predicted risk.
        """

-        check_is_fitted(self)
+        check_is_fitted(self, ["n_features_in_", "classifier_", "regressor_"])
        X = check_array(X)
-        self._check_n_features(X, reset=True)
+
+        if X.shape[1] != self.n_features_in_:
+            msg = f"Unexpected input dimension {X.shape[1]}, expected {self.n_features_in_}"
+            raise ValueError(msg)

        non_zero_proba = self.classifier_.predict_proba(X)[:, 1]
        expected_impact = self.regressor_.predict(X)
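`predict` and `score_samples` now validate the column count explicitly instead of delegating to the `_check_n_features` helper. A small sketch of the resulting error, with ExtraTrees models chosen arbitrarily as the wrapped classifier and regressor:

import numpy as np
from sklearn.ensemble import ExtraTreesClassifier, ExtraTreesRegressor
from sklego.meta import ZeroInflatedRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
y = np.where(rng.random(100) < 0.5, 0.0, np.abs(X[:, 0]))  # zero-inflated target

zir = ZeroInflatedRegressor(
    classifier=ExtraTreesClassifier(random_state=0),
    regressor=ExtraTreesRegressor(random_state=0),
).fit(X, y)

try:
    zir.predict(X[:, :3])  # 3 columns instead of the 4 seen in fit
except ValueError as exc:
    print(exc)  # Unexpected input dimension 3, expected 4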
8 changes: 8 additions & 0 deletions sklego/naive_bayes.py
@@ -118,6 +118,10 @@ def predict(self, X):
        """
        check_is_fitted(self, ["gmms_", "classes_", "n_features_in_"])
        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+
+        if self.n_features_in_ != X.shape[1]:
+            raise ValueError(f"number of columns {X.shape[1]} does not match fit size {self.n_features_in_}")
+
        return self.classes_[self.predict_proba(X).argmax(axis=1)]

    def predict_proba(self, X: np.ndarray):
@@ -284,6 +288,10 @@ def predict(self, X):
        """
        check_is_fitted(self, ["gmms_", "classes_", "n_features_in_"])
        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
+
+        if self.n_features_in_ != X.shape[1]:
+            raise ValueError(f"number of columns {X.shape[1]} does not match fit size {self.n_features_in_}")
+
        return self.classes_[self.predict_proba(X).argmax(axis=1)]

    def predict_proba(self, X: np.ndarray):
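Both mixture-based classifiers in this module gain the same guard in `predict`. A sketch assuming `GaussianMixtureNB` is the estimator behind the first hunk:

import numpy as np
from sklego.naive_bayes import GaussianMixtureNB

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 2))
y = rng.integers(0, 2, size=60)

nb = GaussianMixtureNB().fit(X, y)
try:
    nb.predict(X[:, :1])  # fitted on 2 columns, given 1
except ValueError as exc:
    print(exc)  # number of columns 1 does not match fit size 2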
5 changes: 2 additions & 3 deletions sklego/preprocessing/columncapper.py
@@ -96,9 +96,6 @@ def __init__(
        discard_infs=False,
        copy=True,
    ):
-        self._check_quantile_range(quantile_range)
-        self._check_interpolation(interpolation)
-
        self.quantile_range = quantile_range
        self.interpolation = interpolation
        self.discard_infs = discard_infs
@@ -124,6 +121,8 @@ def fit(self, X, y=None):
        ValueError
            If `X` contains non-numeric columns.
        """
+        self._check_quantile_range(self.quantile_range)
+        self._check_interpolation(self.interpolation)
        X = check_array(X, copy=True, force_all_finite=False, dtype=FLOAT_DTYPES, estimator=self)

        # If X contains infs, we need to replace them by nans before computing quantiles
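As with the linear models above, ColumnCapper's parameter checks move from `__init__` to `fit`. A sketch with a deliberately invalid `quantile_range` (assuming a reversed range fails `_check_quantile_range`):

import numpy as np
from sklego.preprocessing import ColumnCapper

capper = ColumnCapper(quantile_range=(90, 10))  # no longer raises here
try:
    capper.fit(np.random.default_rng(0).normal(size=(10, 2)))
except ValueError as exc:
    print(exc)  # the invalid quantile_range now surfaces at fit time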
