I managed to write a JIT-able version of a solver based on Laguerre's method by replacing the …

I've done some reading on propagating gradients through a solver, and if I understand it correctly, the implicit function theorem says that to compute the gradient of the roots $z^*$ with respect to the model parameters (the polynomial coefficients), I only need to evaluate the first derivative $p'(z)$ of my complex polynomial at the roots $z^*$ (plus the derivatives of $p$ with respect to its coefficients, which are just powers of $z^*$). Since $p$ is 5th order, $p'$ is just a 4th-order complex polynomial. I can do this either by writing a custom VJP function as in this tutorial or by using the more convenient …

I'm still confused about how to use …
Hi all! I'm trying to implement an astrophysics model in JAX. The central part of the model involves solving for the roots of a 5th-order complex polynomial, and I need to do that for a large grid of complex numbers in parallel on a GPU.
`jax.numpy.roots` works, but it relies on `np.linalg.eig`, which isn't supported on the GPU (see #1259). Is there an alternative algorithm that would work on a GPU?

I tried implementing a simple solver using Laguerre's method in JAX, but I'm not sure how to implement the `tangent_solve` function that `lax.custom_root` requires as an argument. Also, my naive implementation isn't JIT-able, so it's orders of magnitude slower than `jax.numpy.roots`.

Any help would be much appreciated :).
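For a single root, `lax.custom_root` can handle the implicit differentiation itself. A sketch, assuming complex coefficients stored highest degree first: `tangent_solve` receives `g`, the function `f` linearized at the root, and must return `x` with `g(x) = y`. For a scalar holomorphic `f` such as a polynomial, `g(x) = p'(z*) * x`, so dividing by `g(1)` is enough; this mirrors the scalar recipe in the `custom_root` docstring, except that a complex `1` is passed so its dtype matches the complex root. `find_root` and `_laguerre_step` are hypothetical names, and the fixed iteration count via `lax.fori_loop` is an assumption chosen to keep the solver jit-able and vmappable.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _laguerre_step(coeffs, d1, d2, z):
    """One Laguerre update for p(z) = polyval(coeffs, z) of degree n."""
    n = coeffs.shape[0] - 1
    p = jnp.polyval(coeffs, z)
    G = jnp.polyval(d1, z) / p
    H = G * G - jnp.polyval(d2, z) / p
    sq = jnp.sqrt((n - 1) * (n * H - G * G))
    # Pick the sign that makes the denominator larger in magnitude.
    den = jnp.where(jnp.abs(G + sq) >= jnp.abs(G - sq), G + sq, G - sq)
    step = n / den
    # At an exact root (p == 0) or a zero denominator the step is inf/nan: stay put.
    step = jnp.where(jnp.isfinite(step), step, 0.0)
    return z - step


def find_root(coeffs, z0, num_iters=50):
    """One root of the polynomial `coeffs` (highest degree first), started at z0."""
    z0 = jnp.asarray(z0, dtype=coeffs.dtype)
    d1 = jnp.polyder(coeffs)
    d2 = jnp.polyder(d1)

    def f(z):
        # Closing over `coeffs` here is what lets custom_root differentiate
        # the root with respect to the coefficients.
        return jnp.polyval(coeffs, z)

    def solve(_f, z):
        # Forward solve only; custom_root never differentiates through it.
        body = lambda _, zz: _laguerre_step(coeffs, d1, d2, zz)
        return lax.fori_loop(0, num_iters, body, z)

    def tangent_solve(g, y):
        # g(x) = p'(z*) * x, so g(x) = y is solved by y / g(1).
        return y / g(jnp.ones_like(y))

    return lax.custom_root(f, z0, solve, tangent_solve)


# Hypothetical usage: one root per row of a batch of coefficient vectors.
batched_find_root = jax.jit(jax.vmap(find_root, in_axes=(0, None)))
```

Mapping over a grid of coefficient vectors with `jax.vmap` keeps everything in one jitted computation, which is the GPU-friendly shape the question asks for; unlike `jax.numpy.roots`, this returns only the root reached from `z0`, so finding all five would need several starting points or deflation.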