Use dtype dependent precision #844

Merged (12 commits, Nov 8, 2024)
8 changes: 6 additions & 2 deletions CHANGELOG.rst
@@ -7,13 +7,17 @@
 Changelog
 =========

-3.0.3 - unreleased
+3.1.0 - unreleased
 ------------------

-**Bug fix:
+**Bug fix:**

 - Fixed a bug where :meth:`glum.GeneralizedLinearRegressor.fit` would raise a ``dtype`` mismatch error if fit with ``alpha_search=True``.

+**Other changes:**
+
+- Use data type (``float64`` or ``float32``) dependent precision in solvers.
+
 3.0.2 - 2024-06-25
 ------------------
5 changes: 3 additions & 2 deletions src/glum/_cd_fast.pyx
@@ -117,7 +117,8 @@ def enet_coordinate_descent_gram(int[::1] active_set,
                                  bint has_lower_bounds,
                                  floating[:] lower_bounds,
                                  bint has_upper_bounds,
-                                 floating[:] upper_bounds):
+                                 floating[:] upper_bounds,
+                                 floating eps):
     """Cython version of the coordinate descent algorithm
     for Elastic-Net regression
     We minimize
@@ -162,7 +163,7 @@ def enet_coordinate_descent_gram(int[::1] active_set,
     else:
         P1_ii = P1[ii - intercept]

-    if Q[active_set_ii, active_set_ii] == 0.0:
+    if Q[active_set_ii, active_set_ii] <= eps:
         continue

     w_ii = w[ii]  # Store previous value
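Why a dtype-scaled tolerance instead of an exact-zero test: with float32 inputs, a Gram-matrix diagonal entry that is analytically zero can come out as a tiny nonzero value, so the skip branch above never fired. A minimal NumPy sketch of the failure mode (illustrative values; the 16 * eps factor matches what _cd_solver passes in below):

import numpy as np

# A constant feature: after centering, its variance is analytically zero, but
# float32 rounding in the mean can leave tiny nonzero residuals.
x = np.full(1000, 0.1, dtype=np.float32)
x_centered = x - x.mean()

q_diag = x_centered @ x_centered      # one diagonal entry of Q = X.T @ X
eps = np.finfo(np.float32).eps * 16   # ~1.9e-6

print(q_diag == 0.0)  # old check: may be False despite the analytic zero
print(q_diag <= eps)  # new check: True, so the coordinate is safely skipped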
2 changes: 1 addition & 1 deletion src/glum/_glm.py
@@ -452,7 +452,7 @@ def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray:

     If values are zeros, return val.
     """
-    zeros = np.where(np.abs(arr) < 1e-7)
+    zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps))
Member commented:

@mlondschien, what was the logic for this change? It's a slightly stricter criterion for float64 than before, which is causing test failures downstream.

@mlondschien (Contributor, Author) replied on Nov 8, 2024:

In [4]: np.sqrt(np.finfo(np.float32).eps)
Out[4]: np.float32(0.00034526698)

In [5]: np.sqrt(np.finfo(np.float64).eps)
Out[5]: np.float64(1.4901161193847656e-08)

None. I thought that's where the 1e-7 came from. I don't really know what would be "good" values here. Just that 1e-7 is too strict for float32. Sorry!

I guess zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps) / 10) should do the trick 🤷 .

Essentially we're doing

var = 0 - 0
std = np.sqrt(var)
normalized_values = values * _one_over_var_inf_to_val(std, 1)

     with np.errstate(divide="ignore"):
         one_over = 1 / arr
     one_over[zeros] = val
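For reference, a self-contained sketch of the helper after this change (body assembled from the diff context above; the trailing return is assumed):

import numpy as np

def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray:
    """Return 1 / arr, mapping (numerically) zero entries to val."""
    zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps))
    with np.errstate(divide="ignore"):
        one_over = 1 / arr
    one_over[zeros] = val
    return one_over

std = np.array([2.0, 1e-5, 0.0], dtype=np.float32)
print(_one_over_var_inf_to_val(std, 1.0))
# -> [0.5, 1.0, 1.0]: 1e-5 is below sqrt(float32 eps) ~ 3.45e-4 and is treated
#    as zero, whereas the old fixed 1e-7 cutoff would have returned 1e5.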
11 changes: 9 additions & 2 deletions src/glum/_solvers.py
@@ -70,6 +70,7 @@ def _cd_solver(state, data, active_hessian):
         data._lower_bounds,
         data.has_upper_bounds,
         data._upper_bounds,
+        np.finfo(state.coef.dtype).eps * 16,
     )
     return new_coef - state.coef, n_cycles
@@ -759,7 +760,8 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
"""
# line search parameters
(beta, sigma) = (0.5, 0.0001)
eps = 16 * np.finfo(state.obj_val.dtype).eps # type: ignore
# state.obj_val is np.float64, even if coef is np.float32
mlondschien marked this conversation as resolved.
Show resolved Hide resolved
eps = 16 * np.finfo(state.coef.dtype).eps # type: ignore

# line search by sequence beta^k, k=0, 1, ..
# F(w + lambda d) - F(w) <= lambda * bound
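A short sketch of why eps must track the coefficient dtype (illustrative numbers): a tolerance of 16 * float64-eps is around 3.6e-15, far below anything float32 arithmetic can resolve, so a float32 fit could stall against it indefinitely.

import numpy as np

eps64 = 16 * np.finfo(np.float64).eps   # ~3.55e-15 (old: from obj_val's dtype)
eps32 = 16 * np.finfo(np.float32).eps   # ~1.91e-6  (new: from float32 coefs)

obj = np.float32(1.2345678)
step = np.float32(1e-8)         # a change below float32 resolution at this scale
print(obj + step == obj)        # True: the change is invisible in float32
print(step < eps32 * abs(obj))  # True: the new eps classifies it as negligible
print(step < eps64 * abs(obj))  # False: the old eps would not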
@@ -792,7 +794,12 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
     )
     # 1. Check Armijo / sufficient decrease condition.
     loss_improvement = obj_val_wd - state.obj_val
-    if mu_wd.max() < 1e43 and loss_improvement <= factor * bound:
+    if mu_wd.dtype == np.float32:
+        large_number = 1e30
+    else:
+        large_number = 1e43
+
+    if mu_wd.max() < large_number and loss_improvement <= factor * bound:
         break
     # 2. Deal with relative loss differences around machine precision.
     tiny_loss = np.abs(state.obj_val * eps)  # type: ignore
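The old 1e43 guard is vacuous for float32, whose largest finite value is about 3.4e38: every finite float32 compares below 1e43, so divergence of the fitted means was never caught. A small demonstration:

import numpy as np

print(np.finfo(np.float32).max)          # 3.4028235e+38
mu = np.array([1e35], dtype=np.float32)  # a clearly diverging mean
print(mu.max() < 1e43)                   # True: the old guard never fires
print(mu.max() < 1e30)                   # False: the new float32 guard catches it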