From 990ca835612b63b73a7d657f7139e83b10c93ff5 Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Mon, 9 Sep 2024 10:43:23 +0200 Subject: [PATCH] Follow scikit-learn API changes --- src/glum/_glm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/glum/_glm.py b/src/glum/_glm.py index dde0bfbc..2ace19a8 100644 --- a/src/glum/_glm.py +++ b/src/glum/_glm.py @@ -27,6 +27,7 @@ import scipy.sparse as sps import scipy.sparse.linalg as splinalg import sklearn as skl +import sklearn.utils.validation import tabmat as tm from formulaic import Formula, FormulaSpec from formulaic.parser import DefaultFormulaParser @@ -43,6 +44,17 @@ column_or_1d, ) +if hasattr(sklearn.utils.validation, "validate_data"): + validate_data = sklearn.utils.validation.validate_data +else: + validate_data = BaseEstimator._validate_data + +if hasattr(sklearn.utils.validation, "_check_n_features"): + _check_n_features = sklearn.utils.validation._check_n_features +else: + _check_n_features = BaseEstimator._check_n_features + + from ._distribution import ( BinomialDistribution, ExponentialDispersionModel, @@ -2506,9 +2518,10 @@ def _set_up_and_check_fit_args( drop_first=getattr(self, "drop_first", False), **{keyword_finiteness: force_all_finite}, ) - self._check_n_features(X, reset=True) + _check_n_features(self, X, reset=True) else: - X, y = self._validate_data( + X, y = validate_data( + self, X, y, ensure_2d=True,