-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create lax.zeta with native HLO lowering #17144
Conversation
x = np.array([1e5, 1e19, 1e10], dtype=np.float32) | ||
q = np.array([1., 40., 30.], dtype=np.float32) | ||
self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I removed this test because it fails for the native XLA zeta
implementation. As noted in the reporting issue (#3758), both scipy.special.zeta
and tf.math.zeta
also fail this test; with this in mind I think it's not too big an issue if JAX's implementation does as well.
CC/ @fehiepsi – curious about your thoughts here as the contributor of the original zeta implementation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice use of custom JVP! Good idea asking @fehiepsi for any thoughts as well.
Another idea I thought of: we could keep the iterative implementation available via an optional @fehiepsi - one question: I noticed that the series solution for |
It seems that the results are wrong for complex numbers. I tested against mpmath:
|
Thanks - that's good to know that replacing it with a real-only function isn't an issue! |
Quick benchmarks:
On
main
branch:On
zeta
branch:The new implementation is about 2x faster in tracing/compilation, and 3-4x faster at runtime.