Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT use dataclass for IRLSData #881

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 27 additions & 49 deletions src/glum/_solvers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import time
import warnings
from typing import Optional, Union
from dataclasses import InitVar, dataclass
from typing import Any, Optional, Union

import numpy as np
from scipy import linalg, sparse
Expand Down Expand Up @@ -408,53 +409,33 @@ def _update(self, n_iter, iteration_runtime, cur_grad_norm):
self.t.update(0)


@dataclass
class IRLSData:
"""Store parameters for the IRLS optimizer."""

def __init__(
self,
X,
y: np.ndarray,
sample_weight: np.ndarray,
P1: Union[np.ndarray, sparse.spmatrix],
P2: Union[np.ndarray, sparse.spmatrix],
fit_intercept: bool,
family: ExponentialDispersionModel,
link: Link,
max_iter: int = 100,
max_inner_iter: int = 100000,
gradient_tol: Optional[float] = 1e-4,
step_size_tol: Optional[float] = 1e-4,
hessian_approx: float = 0.0,
fixed_inner_tol: Optional[tuple] = None,
selection="cyclic",
random_state=None,
offset: Optional[np.ndarray] = None,
lower_bounds: Optional[np.ndarray] = None,
upper_bounds: Optional[np.ndarray] = None,
verbose: bool = False,
):
self.X = X
self.y = y
self.sample_weight = sample_weight
self.P1 = P1

# Note: we already set P2 = l2*P2, P1 = l1*P1
# Note: we already symmetrized P2 = 1/2 (P2 + P2')
self.P2 = P2

self.fit_intercept = fit_intercept
self.family = family
self.link = link
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.gradient_tol = gradient_tol
self.step_size_tol = step_size_tol
self.hessian_approx = hessian_approx
self.fixed_inner_tol = fixed_inner_tol
self.selection = selection
self.random_state = random_state
self.offset = offset
X: Any
y: np.ndarray
sample_weight: np.ndarray
# Note: we already set P2 = l2*P2, P1 = l1*P1 and symmetrized P2 = 1/2 (P2 + P2')
P1: Union[np.ndarray, sparse.spmatrix]
P2: Union[np.ndarray, sparse.spmatrix]
fit_intercept: bool
family: ExponentialDispersionModel
link: Link
max_iter: int = 100
max_inner_iter: int = 100000
gradient_tol: Optional[float] = 1e-4
step_size_tol: Optional[float] = 1e-4
hessian_approx: float = 0.0
fixed_inner_tol: Optional[tuple] = None
selection: str = "cyclic"
random_state: Union[None, int, np.random.RandomState] = None
offset: Optional[np.ndarray] = None
lower_bounds: InitVar[Optional[np.ndarray]] = None
upper_bounds: InitVar[Optional[np.ndarray]] = None
verbose: bool = False

def __post_init__(self, lower_bounds, upper_bounds):
self.has_lower_bounds, self._lower_bounds = _setup_bounds(
lower_bounds, self.X.dtype
)
Expand All @@ -463,11 +444,8 @@ def __init__(
)

self.intercept_offset = 1 if self.fit_intercept else 0
self.verbose = verbose

self._check_data()

def _check_data(self):
# Check data
if self.P2.ndim == 2:
self.P2 = check_array(self.P2, "csc", dtype=[np.float64, np.float32])

Expand Down
Loading