Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with scikit-learn 1.6rc1 #720

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sklego/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import re
import sys

if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata


__title__ = "sklego"
__version__ = metadata.version("scikit-lego")

SKLEARN_VERSION = tuple(int(re.sub(r"\D", "", str(v))) for v in metadata.version("scikit-learn").split("."))
42 changes: 39 additions & 3 deletions sklego/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

from sklego import SKLEARN_VERSION


class TrainOnlyTransformerMixin(TransformerMixin, BaseEstimator):
"""Mixin class for transformers that can handle training and test data differently.
Expand Down Expand Up @@ -79,9 +81,9 @@ def fit(self, X, y=None):
The fitted transformer.
"""
if y is None:
check_array(X, estimator=self)
validate_data(self, X)
else:
check_X_y(X, y, estimator=self, multi_output=True)
validate_data(self, X, y, multi_output=True)
self.X_hash_ = self._hash(X)
self.n_features_in_ = X.shape[1]
return self
Expand Down Expand Up @@ -145,7 +147,7 @@ def transform(self, X, y=None):
If the input dimension does not match the training dimension.
"""
check_is_fitted(self, ["X_hash_", "n_features_in_"])
check_array(X, estimator=self)
validate_data(self, X, reset=False)

if X.shape[1] != self.n_features_in_:
raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self.n_features_in_}")
Expand Down Expand Up @@ -339,3 +341,37 @@ def sliding_window(sequence, window_size, step_size):
```
"""
return (sequence[pos : pos + window_size] for pos in range(0, len(sequence), step_size))


def validate_data(
estimator,
X="no_validation",
y="no_validation",
reset=True,
validate_separately=False,
skip_check_array=False,
y_required=False,
**check_params,
):
if SKLEARN_VERSION >= (1, 6):
from sklearn.utils.validation import validate_data

return validate_data(
estimator,
X=X,
y=y,
reset=reset,
validate_separately=validate_separately,
skip_check_array=skip_check_array,
**check_params,
)

else:
if y is None and y_required:
msg = f"This {estimator.__class__.__name__} estimator requires y to be passed, but the target y is None."
raise ValueError(msg)

if y is None or (isinstance(y, str) and y == "no_validation"):
return check_array(X, estimator=estimator, **check_params)
else:
return check_X_y(X=X, y=y, estimator=estimator, **check_params)
8 changes: 5 additions & 3 deletions sklego/decomposition/pca_reconstruction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from sklearn.base import BaseEstimator, OutlierMixin
from sklearn.decomposition import PCA
from sklearn.utils.validation import FLOAT_DTYPES, check_array, check_is_fitted
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego.common import validate_data


class PCAOutlierDetection(OutlierMixin, BaseEstimator):
Expand Down Expand Up @@ -94,7 +96,7 @@ def fit(self, X, y=None):
ValueError
If `threshold` is `None`.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES)
if not self.threshold:
raise ValueError("The `threshold` value cannot be `None`.")

Expand Down Expand Up @@ -157,7 +159,7 @@ def predict(self, X):
array-like of shape (n_samples,)
The predicted data. 1 for inliers, -1 for outliers.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES)
check_is_fitted(self, ["pca_", "offset_"])
result = np.ones(X.shape[0])
result[self.difference(X) > self.threshold] = -1
Expand Down
21 changes: 17 additions & 4 deletions sklego/decomposition/umap_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import numpy as np
from sklearn.base import BaseEstimator, OutlierMixin
from sklearn.utils.validation import FLOAT_DTYPES, check_array, check_is_fitted
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego.common import validate_data


class UMAPOutlierDetection(OutlierMixin, BaseEstimator):
Expand Down Expand Up @@ -100,9 +102,9 @@ def fit(self, X, y=None):
- If `n_components` is less than 2.
- If `threshold` is `None`.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES)
if y is not None:
y = check_array(y, estimator=self, ensure_2d=False)
y = validate_data(self, y, ensure_2d=False)

if not self.threshold:
raise ValueError("The `threshold` value cannot be `None`.")
Expand Down Expand Up @@ -133,6 +135,7 @@ def difference(self, X):
The calculated difference.
"""
check_is_fitted(self, ["umap_", "offset_"])

reduced = self.umap_.transform(X)
diff = np.sum(np.abs(self.umap_.inverse_transform(reduced) - X), axis=1)
if self.variant == "relative":
Expand All @@ -155,7 +158,7 @@ def predict(self, X):
array-like of shape (n_samples,)
The predicted data. 1 for inliers, -1 for outliers.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES, reset=False)
check_is_fitted(self, ["umap_", "offset_"])
result = np.ones(X.shape[0])
result[self.difference(X) > self.threshold] = -1
Expand All @@ -172,3 +175,13 @@ def score_samples(self, X):

def _more_tags(self):
return {"non_deterministic": True}

def __sklearn_tags__(self):
from sklego import SKLEARN_VERSION

if SKLEARN_VERSION >= (1, 6):
tags = super().__sklearn_tags__()
tags.non_deterministic = True
return tags
else:
pass
21 changes: 16 additions & 5 deletions sklego/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils import check_X_y
from sklearn.utils.validation import (
FLOAT_DTYPES,
check_array,
check_is_fitted,
check_random_state,
)

from sklego.common import validate_data


class RandomRegressor(RegressorMixin, BaseEstimator):
"""A `RandomRegressor` makes random predictions only based on the `y` value that is seen.
Expand Down Expand Up @@ -72,7 +72,7 @@ def fit(self, X: np.array, y: np.array) -> "RandomRegressor":
"""
if self.strategy not in self._ALLOWED_STRATEGIES:
raise ValueError(f"strategy {self.strategy} is not in {self._ALLOWED_STRATEGIES}")
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
X, y = validate_data(self, X, y, dtype=FLOAT_DTYPES, y_required=True)
self.n_features_in_ = X.shape[1]

self.min_ = np.min(y)
Expand All @@ -99,9 +99,9 @@ def predict(self, X):
rs = check_random_state(self.random_state)
check_is_fitted(self, ["n_features_in_", "min_", "max_", "mu_", "sigma_"])

X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES, reset=False)
if X.shape[1] != self.n_features_in_:
raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self.dim_}")
raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self.n_features_in_}")

if self.strategy == "normal":
return rs.normal(self.mu_, self.sigma_, X.shape[0])
Expand All @@ -127,3 +127,14 @@ def allowed_strategies(self):

def _more_tags(self):
return {"poor_score": True, "non_deterministic": True}

def __sklearn_tags__(self):
from sklego import SKLEARN_VERSION

if SKLEARN_VERSION >= (1, 6):
tags = super().__sklearn_tags__()
tags.non_deterministic = True
tags.regressor_tags.poor_score = True
return tags
else:
pass
7 changes: 5 additions & 2 deletions sklego/feature_selection/mrmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from sklearn.base import BaseEstimator
from sklearn.feature_selection import f_classif, f_regression
from sklearn.feature_selection._base import SelectorMixin
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego.common import validate_data


def _redundancy_pearson(X, selected, left):
Expand Down Expand Up @@ -201,7 +203,8 @@ def fit(self, X, y):

k parameter is not integer type or is < n_features_in (X.shape[1]) or < 1
"""
X, y = check_X_y(X, y, dtype="numeric", y_numeric=True)
X, y = validate_data(self, X, y, dtype="numeric", y_numeric=True, y_required=True)

self._y_dtype = y.dtype

relevance = self._get_relevance
Expand Down
37 changes: 25 additions & 12 deletions sklego/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
from sklearn.linear_model._base import LinearClassifierMixin
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_X_y
from sklearn.utils.validation import (
FLOAT_DTYPES,
_check_sample_weight,
check_array,
check_is_fitted,
check_X_y,
column_or_1d,
)

from sklego.common import validate_data


class LowessRegression(RegressorMixin, BaseEstimator):
"""`LowessRegression` estimator: LOWESS (Locally Weighted Scatterplot Smoothing) is a type of
Expand Down Expand Up @@ -96,7 +98,7 @@ def fit(self, X, y):
- If `span` is not between 0 and 1.
- If `sigma` is negative.
"""
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
X, y = validate_data(self, X, y, dtype=FLOAT_DTYPES, y_required=True)
if self.span is not None:
if not 0 <= self.span <= 1:
raise ValueError(f"Param `span` must be 0 <= span <= 1, got: {self.span}")
Expand Down Expand Up @@ -138,7 +140,7 @@ def predict(self, X):
array-like of shape (n_samples,)
The predicted values.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES, reset=False)
check_is_fitted(self, ["X_", "y_"])

try:
Expand Down Expand Up @@ -233,7 +235,7 @@ def fit(self, X, y):
self : ProbWeightRegression
The fitted estimator.
"""
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
X, y = validate_data(self, X, y, dtype=FLOAT_DTYPES, y_required=True)

# Construct the problem.
betas = cp.Variable(X.shape[1])
Expand Down Expand Up @@ -263,7 +265,7 @@ def predict(self, X):
array-like of shape (n_samples,)
The predicted data.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES, reset=False)
check_is_fitted(self, ["coef_"])
return np.dot(X, self.coef_)

Expand Down Expand Up @@ -381,7 +383,7 @@ def fit(self, X, y):
ValueError
If `effect` is not one of "linear", "quadratic" or "constant".
"""
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES)
X, y = validate_data(self, X, y, dtype=FLOAT_DTYPES, y_required=True)
if self.effect not in self._ALLOWED_EFFECTS:
raise ValueError(f"effect {self.effect} must be in {self._ALLOWED_EFFECTS}")

Expand Down Expand Up @@ -458,7 +460,7 @@ def predict(self, X):
array-like of shape (n_samples,)
The predicted data.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
X = validate_data(self, X, dtype=FLOAT_DTYPES, reset=False)
check_is_fitted(self, ["coef_"])
return np.dot(X, self.coef_)

Expand Down Expand Up @@ -579,7 +581,7 @@ def fit(self, X, y):
if isinstance(X, nw.DataFrame):
self.sensitive_col_idx_ = [i for i, name in enumerate(X.columns) if name in self.sensitive_cols]

X, y = check_X_y(X, y, accept_large_sparse=False)
X, y = check_X_y(X, y, accept_large_sparse=False, estimator=self)
sensitive = X[:, self.sensitive_col_idx_]

if not self.train_sensitive_cols:
Expand Down Expand Up @@ -681,6 +683,16 @@ def decision_function(self, X):
def _more_tags(self):
return {"poor_score": True}

def __sklearn_tags__(self):
from sklego import SKLEARN_VERSION

if SKLEARN_VERSION >= (1, 6):
tags = super().__sklearn_tags__()
tags.classifier_tags.poor_score = True
return tags
else:
pass


class DemographicParityClassifier(LinearClassifierMixin, BaseEstimator):
r"""`DemographicParityClassifier` is a logistic regression classifier which can be constrained on demographic
Expand Down Expand Up @@ -970,8 +982,6 @@ def __init__(
self.fit_intercept = fit_intercept
self.copy_X = copy_X
self.positive = positive
if method not in ("SLSQP", "TNC", "L-BFGS-B"):
raise ValueError(f'method should be one of "SLSQP", "TNC", "L-BFGS-B", ' f"got {method} instead")
self.method = method

@abstractmethod
Expand Down Expand Up @@ -1021,6 +1031,9 @@ def fit(self, X, y, sample_weight=None):
self : BaseScipyMinimizeRegressor
Fitted linear model.
"""
if self.method not in {"SLSQP", "TNC", "L-BFGS-B"}:
msg = f"method should be one of 'SLSQP', 'TNC', 'L-BFGS-B', got {self.method} instead"
raise ValueError(msg)
X_, grad_loss, loss = self._prepare_inputs(X, sample_weight, y)

d = X_.shape[1] - self.n_features_in_ # This is either zero or one.
Expand Down Expand Up @@ -1051,7 +1064,7 @@ def _prepare_inputs(self, X, sample_weight, y):
This method is called by `fit` to prepare the inputs for the optimization problem. It adds an intercept column
to `X` if `fit_intercept=True`, and returns the loss function and its gradient.
"""
X, y = check_X_y(X, y, y_numeric=True)
X, y = validate_data(self, X, y, y_numeric=True, y_required=True)
sample_weight = _check_sample_weight(sample_weight, X)
self.n_features_in_ = X.shape[1]

Expand Down Expand Up @@ -1081,7 +1094,7 @@ def predict(self, X):
The predicted data.
"""
check_is_fitted(self)
X = check_array(X)
X = validate_data(self, X, reset=False)

return X @ self.coef_ + self.intercept_

Expand Down
Loading
Loading