Create lax.zeta with native HLO lowering #17144

Merged 1 commit on Aug 18, 2023

Conversation

jakevdp
Collaborator

@jakevdp jakevdp commented Aug 16, 2023

Quick benchmarks:

import jax
import jax.numpy as jnp
import numpy as np

N = 10000
x = jnp.array(1 + np.random.rand(N) * 10)
y = jnp.array(np.random.rand(N))

zeta = jax.jit(jax.scipy.special.zeta)

%time zeta(x, y).block_until_ready()
%timeit zeta(x, y).block_until_ready()

On main branch:

CPU times: user 193 ms, sys: 287 ms, total: 480 ms
Wall time: 64.8 ms
965 µs ± 3.48 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

On zeta branch:

CPU times: user 272 ms, sys: 27.9 ms, total: 300 ms
Wall time: 38.3 ms
256 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The new implementation is about 2x faster in tracing/compilation, and 3-4x faster at runtime.
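
For reference, the ratios behind that summary can be recomputed from the timings quoted above. This is only arithmetic on the reported numbers; the variable names are illustrative and not part of the PR.

```python
# Timings copied from the two benchmark runs above.
# %time wall time of the first call (tracing/compilation plus one run), in ms.
first_call_ms = {"main": 64.8, "zeta": 38.3}
# %timeit mean time per loop after compilation, in microseconds.
steady_state_us = {"main": 965.0, "zeta": 256.0}

compile_speedup = first_call_ms["main"] / first_call_ms["zeta"]
runtime_speedup = steady_state_us["main"] / steady_state_us["zeta"]

print(f"first call: {compile_speedup:.1f}x faster")
print(f"steady state: {runtime_speedup:.1f}x faster")
```

On these numbers the first call is about 1.7x faster and the steady-state call about 3.8x faster, consistent with the rounded figures in the comment.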

@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Aug 16, 2023
@jakevdp jakevdp self-assigned this Aug 16, 2023
@jakevdp jakevdp marked this pull request as draft August 16, 2023 20:21
x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
q = np.array([1., 40., 30.], dtype=np.float32)
self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

Collaborator Author

@jakevdp jakevdp Aug 16, 2023

Note: I removed this test because it fails for the native XLA zeta implementation. As noted in the reporting issue (#3758), both scipy.special.zeta and tf.math.zeta also fail this test; with this in mind I think it's not too big an issue if JAX's implementation does as well.
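
To see where the expected values [1., 0., 0.] in the removed test come from, recall the series definition zeta(x, q) = sum over n >= 0 of (n + q)**(-x). The sketch below sums only the first few terms in plain Python floats. It illustrates the limiting values and is not the algorithm used by either JAX implementation; `leading_terms` is a hypothetical helper name.

```python
def leading_terms(x, q, n_terms=3):
    """Sum the first few terms of sum_{n>=0} (n + q)**(-x)."""
    return sum((n + q) ** (-x) for n in range(n_terms))

# The (x, q) pairs from the removed test.
cases = [(1e5, 1.0), (1e19, 40.0), (1e10, 30.0)]
values = [leading_terms(x, q) for x, q in cases]
print(values)
```

With q = 1 the first term is exactly 1 and every later term underflows to 0; with q > 1 every term underflows, which is why the mathematically expected result is [1, 0, 0].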

@jakevdp
Collaborator Author

jakevdp commented Aug 16, 2023

CC/ @fehiepsi – curious about your thoughts here as the contributor of the original zeta implementation.

@jakevdp jakevdp requested a review from froystig August 16, 2023 22:15
Member

@froystig froystig left a comment

Nice use of custom JVP! Good idea asking @fehiepsi for any thoughts as well.
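
The JVP rule itself is not quoted in this conversation. As an assumption about what such a rule can build on, the series definition gives d zeta(x, q) / dq = -x * zeta(x + 1, q), which holds term by term. The plain-Python sketch below checks that identity on a truncated series with a central finite difference; `truncated_zeta` is a hypothetical helper and not part of JAX.

```python
def truncated_zeta(x, q, n_terms=200):
    """Truncated series sum_{n=0}^{n_terms-1} (n + q)**(-x)."""
    return sum((n + q) ** (-x) for n in range(n_terms))

x, q, h = 3.0, 2.5, 1e-5

# Central finite difference with respect to q.
numeric = (truncated_zeta(x, q + h) - truncated_zeta(x, q - h)) / (2 * h)
# Derivative obtained by differentiating each term of the series.
analytic = -x * truncated_zeta(x + 1, q)

print(numeric, analytic)
```

The two printed values should agree closely, since the identity also holds exactly for the truncated sum.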

@jakevdp jakevdp marked this pull request as ready for review August 16, 2023 23:38
@fehiepsi
Contributor

LGTM! Adding @srvasude who raised the issue #3758.

@jakevdp
Collaborator Author

jakevdp commented Aug 17, 2023

Another idea I thought of: we could keep the iterative implementation available via an optional impl keyword argument to jax.scipy.special.zeta in case anyone is depending on corner cases of the previous algorithm.

@fehiepsi - one question: I noticed that the series solution for zeta supports complex inputs (untested because the scipy version only supports real inputs). Do you know whether the algorithm produces the correct result in those cases?

@fehiepsi
Contributor

It seems that the results are wrong for complex numbers. I tested against mpmath:

In [17]: mpmath.zeta(2. + 0.2j, 3.)
Out[17]: mpc(real='0.35908544306418988', imag='-0.14527069353550556')

In [18]: special.zeta(2. + 0.2j, 3.)
Out[18]: Array(0.03761286-0.03729502j, dtype=complex64)
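
As an independent cross-check of the mpmath value, the series can be evaluated directly for complex x with Re(x) > 1. The sketch below sums the leading terms and adds a short Euler-Maclaurin tail correction in pure Python; `hurwitz_zeta` is a hypothetical helper written for this note, not the JAX or mpmath implementation.

```python
def hurwitz_zeta(s, q, n_terms=1000):
    """Approximate sum_{n>=0} (n + q)**(-s) for Re(s) > 1 and q > 0."""
    head = sum((n + q) ** (-s) for n in range(n_terms))
    a = n_terms + q
    # Euler-Maclaurin tail: integral term, half of the first omitted term,
    # and the first Bernoulli correction.
    tail = a ** (1 - s) / (s - 1) + 0.5 * a ** (-s) + s * a ** (-s - 1) / 12
    return head + tail

value = hurwitz_zeta(2.0 + 0.2j, 3.0)
print(value)
```

The result should agree with the mpmath output above to many digits, which supports the conclusion that the earlier series implementation was not reliable for complex input.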

@jakevdp
Collaborator Author

jakevdp commented Aug 18, 2023

Thanks - it's good to know that replacing it with a real-only function isn't an issue!

@copybara-service copybara-service bot merged commit 209b6b0 into jax-ml:main Aug 18, 2023
@jakevdp jakevdp deleted the zeta branch August 18, 2023 18:12