Skip to content

Commit

Permalink
Follow scikit-learn API changes
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Sep 9, 2024
1 parent 80b3a2d commit 990ca83
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import scipy.sparse as sps
import scipy.sparse.linalg as splinalg
import sklearn as skl
import sklearn.utils.validation
import tabmat as tm
from formulaic import Formula, FormulaSpec
from formulaic.parser import DefaultFormulaParser
Expand All @@ -43,6 +44,17 @@
column_or_1d,
)

if hasattr(sklearn.utils.validation, "validate_data"):
validate_data = sklearn.utils.validation.validate_data
else:
validate_data = BaseEstimator._validate_data

if hasattr(sklearn.utils.validation, "_check_n_features"):
_check_n_features = sklearn.utils.validation._check_n_features
else:
_check_n_features = BaseEstimator._check_n_features


from ._distribution import (
BinomialDistribution,
ExponentialDispersionModel,
Expand Down Expand Up @@ -2506,9 +2518,10 @@ def _set_up_and_check_fit_args(
drop_first=getattr(self, "drop_first", False),
**{keyword_finiteness: force_all_finite},
)
self._check_n_features(X, reset=True)
_check_n_features(self, X, reset=True)
else:
X, y = self._validate_data(
X, y = validate_data(
self,
X,
y,
ensure_2d=True,
Expand Down

0 comments on commit 990ca83

Please sign in to comment.