Use dtype dependent precision #844
Conversation
It would be very cool to have float32 support that "just works". I would expect that you will run into a couple more issues. In 653d6f1 I'm now running the test suite on a float32 dataset. This actually looks pretty good, it's just that on the inference side, we're still expecting doubles in a lot of places.
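One way to exercise float32 end to end, sketched here as a minimal stand-in (this is not glum's actual test suite; `check_fit_preserves_dtype` and the least-squares "fit" are placeholders for the real estimator):

```python
import numpy as np


def check_fit_preserves_dtype(dtype):
    # Hypothetical check: build data in the given dtype and verify that a
    # fit-like computation keeps that dtype instead of upcasting to float64.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 3)).astype(dtype)
    y = X @ np.ones(3, dtype=dtype)
    # Stand-in for the real solver: np.linalg.lstsq preserves the input dtype.
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    assert coef.dtype == dtype


# Running the same check for both dtypes surfaces dtype-mixing bugs.
for dt in (np.float32, np.float64):
    check_fit_preserves_dtype(dt)
```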
This is an example fix for one of the mistakes causing the errors on @jtilly's branch:

```diff
--- a/src/glum/_glm.py
+++ b/src/glum/_glm.py
@@ -2128,7 +2128,7 @@ class GeneralizedLinearRegressorBase(BaseEstimator, RegressorMixin):
         )
         if (
-            np.linalg.cond(_safe_toarray(X.sandwich(np.ones(X.shape[0]))))
+            np.linalg.cond(_safe_toarray(X.sandwich(np.ones(X.shape[0], dtype=X.dtype))))
             > 1 / sys.float_info.epsilon**2
         ):
             raise np.linalg.LinAlgError(
```

There are a bunch of similar ones in the functions used for calculating the covariance matrix.
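The underlying pitfall can be seen with plain NumPy, without tabmat's sandwich product: `np.ones` defaults to float64, so mixing it with a float32 matrix silently promotes the result back to float64.

```python
import numpy as np

X = np.ones((4, 2), dtype=np.float32)

v64 = np.ones(X.shape[0])                 # float64 by default
v32 = np.ones(X.shape[0], dtype=X.dtype)  # matches X, as in the fix above

# Mixed float32/float64 arithmetic promotes to float64:
assert (X.T @ v64).dtype == np.float64
# With a matching dtype, everything stays float32:
assert (X.T @ v32).dtype == np.float32
```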
I think there are also quite a few teething problems ("Kinderkrankheiten") that are not covered by the tests, e.g. when run on "real data", probably due to fixed convergence tolerances. Setting
Yes, this is a bit of a rabbit hole. We looked into this when we built it; I think we'll also have to do a bit of work in
Works fine with

Edit: reproducer here: https://github.com/Quantco/tabmat/compare/test-float32?expand=1
I'm having issues finding an
Two questions about the convergence criteria:
Do you have a reference on how to improve convergence? For reasonable
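A common convention for dtype-aware tolerances (an assumption on my part, not glum's current behavior, and `default_gtol` is a hypothetical name) is to scale them with the dtype's machine epsilon rather than hard-coding a constant tuned for float64:

```python
import numpy as np


def default_gtol(dtype):
    # sqrt(eps) is a standard rule of thumb for convergence tolerances:
    # ~1.5e-8 for float64, ~3.5e-4 for float32.
    return float(np.sqrt(np.finfo(dtype).eps))


tol64 = default_gtol(np.float64)
tol32 = default_gtol(np.float32)
```

With a fixed float64-sized tolerance, a float32 solver can never satisfy the stopping criterion and simply runs out its iteration budget.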
…s is different between the fit and predict methods. (Quantco#848)

* Check number of features when predicting
* Add changelog entry

Force-pushed from 6e05960 to 758ec9d.

Co-authored-by: Luca Bittarello <15511539+lbittarello@users.noreply.github.com>
I think it would make sense to extend the test suite to float32s before merging this (see 653d6f1). There are still a few places where we are mixing dtypes, and the new tests would uncover many of those. Or does that belong in a separate PR?

I am unsure whether the tests would have caught anything here. The solvers would have simply run until

Yeah, that's true. Let's fix the remaining dtype-related bugs elsewhere then.

@MarcAntoineSchmidtQC can we merge this?

@stanmart, ping
I'd love to see how well the optimization converges using float32s, but I'm also okay with adding such tests later.

See #872 for an example of "bad convergence" with
Can we merge this and #865 and then work on the convergence issues? For me, hotfixing Quantco/tabmat#408 locally already gave me a huge improvement in convergence.
```diff
@@ -452,7 +452,7 @@ def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray:

     If values are zeros, return val.
     """
-    zeros = np.where(np.abs(arr) < 1e-7)
+    zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps))
```
@mlondschien, what was the logic for this change? It's a slightly stricter criterion for float64 than before, which is causing test failures downstream.
```python
In [4]: np.sqrt(np.finfo(np.float32).eps)
Out[4]: np.float32(0.00034526698)

In [5]: np.sqrt(np.finfo(np.float64).eps)
Out[5]: np.float64(1.4901161193847656e-08)
```

None. I thought that's where the 1e-7 came from. I don't really know what would be "good" values here, just that 1e-7 is too strict for float32. Sorry!

I guess `zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps) / 10)` should do the trick 🤷.
Essentially we're doing

```python
var = 0 - 0
std = np.sqrt(var)
normalized_values = values * _one_over_var_inf_to_val(std, 1)
```
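To make the discussion concrete, here is a self-contained sketch of the helper (simplified from glum's `_one_over_var_inf_to_val`; the exact implementation may differ): invert `arr` elementwise, but return `val` wherever `arr` is numerically zero under the dtype-dependent threshold.

```python
import numpy as np


def one_over_var_inf_to_val(arr, val):
    # Entries below sqrt(eps) for arr's dtype are treated as exact zeros,
    # so 1/0 never produces inf; those entries get `val` instead.
    out = np.empty_like(arr)
    zeros = np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps)
    out[zeros] = val
    np.divide(1.0, arr, out=out, where=~zeros)
    return out


std = np.array([0.0, 2.0], dtype=np.float32)
# A zero standard deviation maps to val=1, so the normalization
# `values * one_over_var_inf_to_val(std, 1)` is a no-op for that column;
# the nonzero entry is simply inverted (1/2 = 0.5).
result = one_over_var_inf_to_val(std, 1.0)
```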
xref #843