scipy.special.zeta returning NaNs #17248

Closed
hongwanliu opened this issue Aug 23, 2023 · 6 comments
Labels
bug Something isn't working

@hongwanliu

hongwanliu commented Aug 23, 2023

Description

After a recent update, jax.scipy.special.zeta returns NaNs for negative arguments of the Riemann zeta function:

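```python
import jax
from jax.scipy.special import zeta

jax.config.update("jax_enable_x64", True)  # needed for the float64 result below

zeta(-3, 1)

# Expected: Array(0.00833333, dtype=float64)
```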

This was working in jax v0.4.14, but is broken in jax v0.4.15.dev20230823.

Relatedly, even in v0.4.14 the results for negative arguments are unreliable: zeta(n, 1) for n <= -9 gives wildly incorrect results, and although zeta(-n, 1) should be exactly zero for every positive even n, jax.scipy.special.zeta returns nonzero values that grow as n increases, e.g. 3.72529030e-09 for zeta(-6, 1).
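The exact zeros at negative even integers (the "trivial zeros") follow from the functional equation of the Riemann zeta function: at s = -2k the sine factor vanishes while the remaining factors stay finite, so any nonzero output there is purely numerical error.

```latex
\begin{aligned}
\zeta(s) &= 2^{s}\,\pi^{s-1}\,\sin\!\left(\tfrac{\pi s}{2}\right)\Gamma(1-s)\,\zeta(1-s), \\
\zeta(-2k) &= 2^{-2k}\,\pi^{-2k-1}\,\underbrace{\sin(-\pi k)}_{=\,0}\,(2k)!\;\zeta(1+2k) = 0,
\qquad k = 1, 2, \dots
\end{aligned}
```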

What jax/jaxlib version are you using?

jax v0.4.15.dev20230823

Which accelerator(s) are you using?

CPU

@hongwanliu added the bug label on Aug 23, 2023
@jakevdp
Collaborator

jakevdp commented Aug 23, 2023

Hi, thanks for the report! This is due to #17144, which changed jax.scipy.special.zeta to use the native HLO lowering. The new implementation is faster and more accurate than the previous one, but it also changes the range of supported inputs. The new jax.scipy.special.zeta behaves more like scipy.special.zeta, which also returns nan in this case:

```python
>>> from scipy.special import zeta
>>> zeta(-3, 1)
nan
```
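For context, with two arguments scipy.special.zeta(x, q) is the Hurwitz zeta function. Its defining series converges only for x > 1 (with q > 0), and q = 1 recovers the Riemann zeta function. Values such as zeta(-3) come from analytic continuation rather than from this series, which is consistent with the nan returned here.

```latex
\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k+q)^{x}}, \qquad x > 1,\ q > 0,
\qquad \zeta(x, 1) = \zeta(x)
```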

I'm not sure what the best fix is here. We want to make the faster and more accurate native zeta function available, since it seems like the best default in most cases, but it clearly doesn't serve your use case well.

What do you think?

@jakevdp self-assigned this on Aug 23, 2023
@hongwanliu
Author

I see. I really only need zeta(-n, 1) for non-negative integers n, and these values are given directly by the Bernoulli numbers: zeta(-n) = (-1)^n B_{n+1} / (n+1). Bernoulli numbers are implemented in scipy.special, but not yet in jax.scipy.special. The scipy.special routine that generates them is here: https://github.com/scipy/scipy/blob/v1.11.2/scipy/special/specfun/specfun.f, and it uses the Chowla and Hartung algorithm described here: https://math.stackexchange.com/questions/2844290/what-is-the-simplest-way-to-get-bernoulli-numbers.
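As a minimal sketch of such a table generator, the code below computes exact Bernoulli numbers with rational arithmetic and then evaluates zeta(-n) from them. Note that it swaps in the classical binomial recurrence sum_{k=0}^{m} C(m+1, k) B_k = 0 rather than the Chowla and Hartung algorithm linked above. It uses the B_1 = -1/2 convention, and the names bernoulli_exact and zeta_neg_int are hypothetical, not SciPy or JAX APIs.

```python
from fractions import Fraction
from math import comb


def bernoulli_exact(n: int) -> list:
    """Exact Bernoulli numbers B_0..B_n as Fractions (convention B_1 = -1/2)."""
    if n < 0:
        raise ValueError("n must be non-negative")
    b = [Fraction(1)] + [Fraction(0)] * n
    for m in range(1, n + 1):
        # Solve sum_{k=0}^{m} C(m+1, k) * B_k = 0 for B_m.
        partial = sum(comb(m + 1, k) * b[k] for k in range(m))
        b[m] = -partial / (m + 1)
    return b


def zeta_neg_int(n: int) -> Fraction:
    """Exact Riemann zeta(-n) for integer n >= 0, via (-1)^n * B_{n+1} / (n+1)."""
    if n < 0:
        raise ValueError("n must be non-negative")
    return (-1) ** n * bernoulli_exact(n + 1)[n + 1] / (n + 1)


# zeta(-3) = 1/120 (about 0.00833333, the value expected in the report),
# and zeta(-6) = 0 exactly, since B_7 = 0.
print(zeta_neg_int(3), float(zeta_neg_int(3)))
print(zeta_neg_int(6))
```

Converting the Fractions to floats once (e.g. `[float(b) for b in bernoulli_exact(30)]`) yields a fixed table that could be stored as a constant array and indexed inside traced code. Because the table is precomputed in exact arithmetic, the even-argument zeros come out exactly 0.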

Any chance something like that can be implemented easily?

@jakevdp
Collaborator

jakevdp commented Aug 23, 2023

I see. I think something like that would be possible to implement, similar to the previous expansion-based implementation of zeta.

One reason I worry about zeta(x, 1) for x < 0 is that I don't think the previous series expansion is guaranteed to work in that regime. It looks like it does work for negative integers x, but we don't have any tests outside the domain it was designed for (essentially x > 1, I believe).
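To illustrate why an expansion-based zeta can happen to work at negative integers, here is a minimal sketch of a Euler-Maclaurin approximation of the Hurwitz zeta function. It assumes, without confirming, that the previous implementation was of this general family. The helper name hurwitz_zeta_em, the default of 10 direct terms, and the six Bernoulli correction terms are arbitrary choices for this sketch. For x > 1 it approximates the series. For integer s with -11 <= s <= 0, the correction terms terminate and the formula is exact in exact arithmetic. In floating point, however, the result is a small difference of terms of size roughly (q + N)^(1 - s), so rounding error grows as s becomes more negative. That is one plausible explanation for the small nonzero values reported at even negative arguments, not a confirmed diagnosis of the old code.

```python
import math

# Even-index Bernoulli numbers B_2, B_4, ..., B_12 used in the correction terms.
_B2J = (1 / 6, -1 / 30, 1 / 42, -1 / 30, 5 / 66, -691 / 2730)


def hurwitz_zeta_em(s: float, q: float, n_terms: int = 10) -> float:
    """Euler-Maclaurin approximation of the Hurwitz zeta(s, q) for q > 0, s != 1.

    Hypothetical helper for illustration; not the previous JAX implementation.
    """
    # Direct part: sum_{k=0}^{N-1} (q + k)^(-s).
    total = sum((q + k) ** (-s) for k in range(n_terms))
    a = q + n_terms
    # Integral and endpoint terms approximating the tail sum_{k>=N} (q + k)^(-s).
    total += a ** (1 - s) / (s - 1) + a ** (-s) / 2
    rising = s  # rising factorial s (s+1) ... (s+2j-2), starting at j = 1
    fact = 2.0  # (2j)!, starting at 2! for j = 1
    for j, b2j in enumerate(_B2J, start=1):
        total += b2j / fact * rising * a ** (-s - 2 * j + 1)
        # Advance to j + 1 by multiplying in the next two factors of each product.
        rising *= (s + 2 * j - 1) * (s + 2 * j)
        fact *= (2 * j + 1) * (2 * j + 2)
    return total


print(hurwitz_zeta_em(2.0, 1.0), math.pi ** 2 / 6)  # these should agree closely
print(hurwitz_zeta_em(-3.0, 1.0))  # close to 1/120
print(hurwitz_zeta_em(-6.0, 1.0))  # 0 in exact arithmetic; only approximately 0 in floats
```

At a negative integer s the rising factorial eventually picks up a factor of zero, which is what truncates the correction series. Nothing comparable happens for non-integer s < 1, which is one reason the expansion's behaviour there is harder to trust without tests.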

A benefit of the new implementation is that it's more explicit about where its output is valid.

@jakevdp
Collaborator

jakevdp commented Aug 23, 2023

If we were to implement a JAX-compatible version of scipy.special.bernoulli, would that serve your needs here?

@hongwanliu
Author

Yes! I currently just have a table of them stored in my code, which does the job, but it would be generally helpful to have them, since they appear in a lot of special functions (such as polylogarithms, which I'm trying to implement).

Thanks for being so responsive!

@jakevdp
Collaborator

jakevdp commented Nov 8, 2023

I think this can be closed, as the workaround is implemented and zeta is behaving as expected.

@jakevdp closed this as completed on Nov 8, 2023