Solver - add a wrapper for scipy L-BFGS solver #165

Merged: 13 commits, Jun 12, 2023
2 changes: 2 additions & 0 deletions doc/api.rst
@@ -37,6 +37,7 @@ Penalties
L0_5
L1
L1_plus_L2
L2
L2_3
MCPenalty
WeightedL1
@@ -78,6 +79,7 @@ Solvers
GramCD
GroupBCD
GroupProxNewton
LBFGS
MultiTaskBCD
ProxNewton

3 changes: 3 additions & 0 deletions skglm/datafits/single_task.py
@@ -182,6 +182,9 @@ def value(self, y, w, Xw):
    def gradient_scalar(self, X, y, w, Xw, j):
        return (- X[:, j] @ (y * sigmoid(- y * Xw))) / len(y)

    def gradient(self, X, y, Xw):
        """Compute the full gradient of the datafit w.r.t. w, as X.T @ raw_grad(y, Xw)."""
        return X.T @ self.raw_grad(y, Xw)

    def full_grad_sparse(
            self, X_data, X_indptr, X_indices, y, Xw):
        n_features = X_indptr.shape[0] - 1
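
A quick consistency check (a sketch, not part of the diff) that the new Logistic.gradient agrees with the existing per-feature gradient_scalar; it assumes raw_grad is available on the compiled datafit, as used above:

import numpy as np
from skglm.datafits import Logistic
from skglm.utils.jit_compilation import compiled_clone

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
y = np.sign(rng.standard_normal(20))
w = rng.standard_normal(5)
Xw = X @ w

datafit = compiled_clone(Logistic())
# full gradient via X.T @ raw_grad vs. feature-by-feature gradient_scalar
grad_full = datafit.gradient(X, y, Xw)
grad_per_feature = np.array(
    [datafit.gradient_scalar(X, y, w, Xw, j) for j in range(X.shape[1])]
)
np.testing.assert_allclose(grad_full, grad_per_feature)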
4 changes: 2 additions & 2 deletions skglm/penalties/__init__.py
@@ -1,6 +1,6 @@
from .base import BasePenalty
from .separable import (
    L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
    PositiveConstraint
)
from .block_separable import (
@@ -12,6 +12,6 @@

__all__ = [
    BasePenalty,
    L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
    L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
    PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
]
31 changes: 31 additions & 0 deletions skglm/penalties/separable.py
@@ -557,3 +557,34 @@ def is_penalized(self, n_features):
    def generalized_support(self, w):
        """Return a mask with non-zero coefficients."""
        return w != 0


class L2(BasePenalty):
    r""":math:`\ell_2` penalty.

    The penalty reads

    .. math::

        \alpha / 2 ||w||_2^2
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def get_spec(self):
        spec = (
            ('alpha', float64),
        )
        return spec

    def params_to_dict(self):
        return dict(alpha=self.alpha)

    def value(self, w):
        """Compute the value of the L2 penalty."""
        return self.alpha * (w ** 2).sum() / 2

    def gradient(self, w):
        """Compute the gradient of the L2 penalty."""
        return self.alpha * w
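
A minimal sketch (not part of the diff) of how this new penalty behaves, using only the value and gradient methods added above:

import numpy as np
from skglm.penalties import L2

penalty = L2(alpha=0.1)
w = np.array([1.0, -2.0, 0.5])

# value: alpha / 2 * ||w||_2^2 = 0.1 / 2 * 5.25 = 0.2625
print(penalty.value(w))
# gradient: alpha * w = [0.1, -0.2, 0.05]
print(penalty.gradient(w))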
3 changes: 2 additions & 1 deletion skglm/solvers/__init__.py
@@ -6,7 +6,8 @@
from .multitask_bcd import MultiTaskBCD
from .prox_newton import ProxNewton
from .group_prox_newton import GroupProxNewton
from .lbfgs import LBFGS


__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton,
           GroupProxNewton]
           GroupProxNewton, LBFGS]
92 changes: 92 additions & 0 deletions skglm/solvers/lbfgs.py
@@ -0,0 +1,92 @@
import warnings
from sklearn.exceptions import ConvergenceWarning

import numpy as np
import scipy.optimize
from numpy.linalg import norm

from skglm.solvers import BaseSolver


class LBFGS(BaseSolver):
"""A wrapper for scipy L-BFGS solver.

Refer to `scipy L-BFGS-B <https://docs.scipy.org/doc/scipy/reference/optimize.
minimize-lbfgsb.html#optimize-minimize-lbfgsb>`_ documentation for details.

Parameters
----------
    max_iter : int, default 50
        Maximum number of iterations.

Collaborator: I believe this is quite small; is it enough on a 1000 x 1000 dataset, for example?
Collaborator (Author): 50?

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool, default False
        Amount of verbosity. 0/False is silent.
    """

    def __init__(self, max_iter=50, tol=1e-4, verbose=False):
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

        def objective_function(w):
            Xw = X @ w
            datafit_value = datafit.value(y, w, Xw)
            penalty_value = penalty.value(w)

            return datafit_value + penalty_value

        def jacobian_function(w):
            Xw = X @ w
            datafit_grad = datafit.gradient(X, y, Xw)
            penalty_grad = penalty.gradient(w)

            return datafit_grad + penalty_grad

        def callback_post_iter(w_k):
            # save p_obj
            p_obj = objective_function(w_k)
            p_objs_out.append(p_obj)

            if self.verbose:
                grad = jacobian_function(w_k)
                stop_crit = norm(grad)

                it = len(p_objs_out)
                print(
                    f"Iteration {it}: {p_obj:.10f}, "
                    f"stopping crit: {stop_crit:.2e}"
                )

        n_features = X.shape[1]
        w = np.zeros(n_features) if w_init is None else w_init
        p_objs_out = []

        result = scipy.optimize.minimize(
            fun=objective_function,
            jac=jacobian_function,
            x0=w,
            method="L-BFGS-B",
            options=dict(
                maxiter=self.max_iter,
                gtol=self.tol
            ),
            callback=callback_post_iter,
        )

        if not result.success:
            warnings.warn(
                f"`LBFGS` did not converge for tol={self.tol:.3e} "
                f"and max_iter={self.max_iter}.\n"
                "Consider increasing `max_iter` and/or `tol`.",
                category=ConvergenceWarning
            )

        w = result.x
        stop_crit = norm(result.jac)

        return w, np.asarray(p_objs_out), stop_crit
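
For reference, a small usage sketch of the new solver (not part of the diff); it reuses the helpers from the test file below and assumes a smooth problem, here logistic regression with the L2 penalty:

import numpy as np
from skglm.datafits import Logistic
from skglm.penalties import L2
from skglm.solvers import LBFGS
from skglm.utils.data import make_correlated_data
from skglm.utils.jit_compilation import compiled_clone

X, y, _ = make_correlated_data(n_samples=100, n_features=30, random_state=0)
y = np.sign(y)

datafit = compiled_clone(Logistic())
penalty = compiled_clone(L2(0.1))

solver = LBFGS(max_iter=50, tol=1e-6, verbose=True)
w, p_objs, stop_crit = solver.solve(X, y, datafit, penalty)

# warm start from a previous solution through the w_init argument
w_refined, *_ = solver.solve(X, y, datafit, penalty, w_init=w)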
40 changes: 40 additions & 0 deletions skglm/tests/test_lbfgs_solver.py
@@ -0,0 +1,40 @@
import numpy as np

from skglm.solvers import LBFGS
from skglm.penalties import L2
from skglm.datafits import Logistic

from sklearn.linear_model import LogisticRegression

from skglm.utils.data import make_correlated_data
from skglm.utils.jit_compilation import compiled_clone


def test_lbfgs_L2_logreg():
    reg = 1.
    n_samples, n_features = 50, 10

    X, y, _ = make_correlated_data(
        n_samples, n_features, random_state=0)
    y = np.sign(y)

    # fit L-BFGS
    datafit = compiled_clone(Logistic())
    penalty = compiled_clone(L2(reg))
    w, *_ = LBFGS().solve(X, y, datafit, penalty)

    # fit scikit learn
    estimator = LogisticRegression(
        penalty='l2',
        C=1 / (n_samples * reg),
        fit_intercept=False
    )
    estimator.fit(X, y)

    np.testing.assert_allclose(
        w, estimator.coef_.flatten(), atol=1e-4
    )


if __name__ == "__main__":
    pass
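
The C passed to scikit-learn in the test encodes an objective match worth spelling out (a sketch of the reasoning, not part of the diff): skglm minimizes the mean logistic loss plus reg / 2 * ||w||^2, while LogisticRegression minimizes the summed loss plus 1 / (2 * C) * ||w||^2, so dividing the latter by n_samples matches the former exactly when C = 1 / (n_samples * reg).

n_samples, reg = 50, 1.
C = 1 / (n_samples * reg)
print(C)  # 0.02, the value passed as C=1 / (n_samples * reg) above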