Skip to content

Commit

Permalink
check_X_y with changed check_array
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Dec 16, 2024
1 parent 6290c99 commit 54e7664
Show file tree
Hide file tree
Showing 21 changed files with 98 additions and 39 deletions.
52 changes: 52 additions & 0 deletions sklego/_sklearn_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,57 @@ def check_array(
**kwargs,
)

def check_X_y(
X,
y,
accept_sparse=False,
*,
accept_large_sparse=True,
dtype="numeric",
order=None,
copy=False,
force_writeable=False,
ensure_all_finite=True,
ensure_2d=True,
allow_nd=False,
multi_output=False,
ensure_min_samples=1,
ensure_min_features=1,
y_numeric=False,
estimator=None,
):
from sklearn.utils.validation import _check_estimator_name, _check_y, check_consistent_length

if y is None:
if estimator is None:
estimator_name = "estimator"
else:
estimator_name = _check_estimator_name(estimator)
raise ValueError(f"{estimator_name} requires y to be passed, but the target y is None")

X = check_array(
X,
accept_sparse=accept_sparse,
accept_large_sparse=accept_large_sparse,
dtype=dtype,
order=order,
copy=copy,
force_writeable=force_writeable,
ensure_all_finite=ensure_all_finite,
ensure_2d=ensure_2d,
allow_nd=allow_nd,
ensure_min_samples=ensure_min_samples,
ensure_min_features=ensure_min_features,
estimator=estimator,
input_name="X",
)

y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric, estimator=estimator)

check_consistent_length(X, y)

return X, y

# tags infrastructure
@dataclass(**_dataclass_args())
class InputTags:
Expand Down Expand Up @@ -516,5 +567,6 @@ def parametrize_with_checks(
_check_feature_names, # noqa: F401
_check_n_features, # noqa: F401
check_array, # noqa: F401
check_X_y, # noqa: F401
validate_data, # noqa: F401
)
4 changes: 2 additions & 2 deletions sklego/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class TrainOnlyTransformerMixin(TransformerMixin, BaseEstimator):
Expand Down
12 changes: 9 additions & 3 deletions sklego/decomposition/umap_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ def fit(self, X, y=None):
- If `n_components` is less than 2.
- If `threshold` is `None`.
"""
X = check_array(X, estimator=self, dtype=FLOAT_DTYPES, ensure_2d=True)
X = check_array(
X,
estimator=self,
dtype=FLOAT_DTYPES,
ensure_2d=True,
ensure_min_features=self.n_components,
ensure_min_samples=2,
)
_check_n_features(self, X, reset=True)
if y is not None:
y = check_array(y, estimator=self)
y = check_array(y, estimator=self, ensure_2d=False)

if not self.threshold:
raise ValueError("The `threshold` value cannot be `None`.")
Expand All @@ -119,7 +126,6 @@ def fit(self, X, y=None):
)
self.umap_.fit(X, y)
self.offset_ = -self.threshold
self.n_features_in_ = X.shape[1]
return self

def difference(self, X):
Expand Down
4 changes: 2 additions & 2 deletions sklego/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state

from ._sklearn_compat import _check_n_features, check_array
from ._sklearn_compat import _check_n_features, check_array, check_X_y


class RandomRegressor(RegressorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/feature_selection/mrmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from sklearn.base import BaseEstimator
from sklearn.feature_selection import f_classif, f_regression
from sklearn.feature_selection._base import SelectorMixin
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego._sklearn_compat import _check_n_features
from sklego._sklearn_compat import _check_n_features, check_X_y


def _redundancy_pearson(X, selected, left):
Expand Down
4 changes: 2 additions & 2 deletions sklego/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from sklearn.linear_model._base import LinearClassifierMixin
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.validation import FLOAT_DTYPES, _check_sample_weight, check_is_fitted, check_X_y, column_or_1d
from sklearn.utils.validation import FLOAT_DTYPES, _check_sample_weight, check_is_fitted, column_or_1d

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class LowessRegression(RegressorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/confusion_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y
from sklego.base import ProbabilisticClassifier


Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/decay_estimator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sklearn import clone
from sklearn.base import BaseEstimator, MetaEstimatorMixin
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features
from sklego._sklearn_compat import _check_n_features, check_X_y
from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay


Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/estimator_transformer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sklearn import clone
from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class EstimatorTransformer(TransformerMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/ordinal_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sklearn import clone
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin, MultiOutputMixin, is_classifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class OrdinalClassifier(MultiOutputMixin, ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/outlier_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from sklearn import clone
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin
from sklearn.calibration import _SigmoidCalibration
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego._sklearn_compat import check_array
from sklego._sklearn_compat import check_array, check_X_y
from sklego.base import OutlierModel


Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/subjective_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class SubjectiveClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/thresholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from sklearn import clone
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_X_y
from sklearn.utils.validation import _check_sample_weight, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array, type_of_target
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y, type_of_target
from sklego.base import ProbabilisticClassifier


Expand Down
4 changes: 2 additions & 2 deletions sklego/meta/zero_inflated_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from sklearn.base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone, is_classifier, is_regressor
from sklearn.exceptions import NotFittedError
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_X_y
from sklearn.utils.validation import _check_sample_weight, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class ZeroInflatedRegressor(RegressorMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/mixture/bayesian_gmm_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.mixture import BayesianGaussianMixture
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class BayesianGMMClassifier(ClassifierMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/mixture/gmm_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.mixture import GaussianMixture
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class GMMClassifier(ClassifierMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class GaussianMixtureNB(ClassifierMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.neighbors import KernelDensity
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


class BayesianKernelDensityClassifier(ClassifierMixin, BaseEstimator):
Expand Down
4 changes: 2 additions & 2 deletions sklego/preprocessing/intervalencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y


def _mk_monotonic_average(xs, ys, intervals, method="increasing", **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions sklego/preprocessing/randomadder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from warnings import warn

from sklearn.base import BaseEstimator
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state, check_X_y
from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state

from sklego._sklearn_compat import _check_n_features, check_array
from sklego._sklearn_compat import _check_n_features, check_array, check_X_y
from sklego.common import TrainOnlyTransformerMixin


Expand Down
1 change: 1 addition & 0 deletions tests/test_estimators/test_quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _create_dataset(coefs, intercept, noise=0.0):
)
def test_sklearn_compatible_estimator(estimator, check):
if check.func.__name__ in {
"check_sample_weights_invariance",
"check_sample_weight_equivalence_on_dense_data",
}:
pytest.skip()
Expand Down

0 comments on commit 54e7664

Please sign in to comment.