Gauss-Newton Vector Products, linearizing once #24940
-
Dear all, I would like to ask for advice on how to accelerate the computation of GGN-vector products. I wrote what I believe should be a reasonably efficient implementation of a GGN-vector product using one `jvp` and one `vjp`. My understanding is that my code linearizes the function twice: once for the `jvp`, and then again for the `vjp` (although the whole code is JIT-compiled, so maybe XLA is able to reuse the Jacobians under the hood). However, the autodiff cookbook suggests that the GGN-vector product can be accelerated by linearizing the function just once (reusing the Jacobian from the `jvp` computation to calculate the `vjp`). As far as I understand, I can use `jax.linearize` for this. Any advice would be super welcome! Best,
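For context, a GGN-vector product with one `jvp` and one `vjp` (the setup described above) might look like the following minimal sketch. The model, loss, and variable names here are illustrative, not taken from the poster's code; for a model `f(theta)` and loss `L(y)`, the GGN is `J^T H_L J`, so the product is computed as `J^T (H_L (J v))`:

```python
import jax
import jax.numpy as jnp

def model(theta):
    # toy nonlinear model R^2 -> R^2 (stand-in for the real network)
    return jnp.tanh(theta) * jnp.array([1.0, 2.0])

def loss(y):
    # toy loss with a non-trivial Hessian w.r.t. the model outputs
    return jnp.sum(jnp.exp(y))

def ggn_vector_product(theta, v):
    # J @ v: forward mode through the model
    y, Jv = jax.jvp(model, (theta,), (v,))
    # H_L @ (J v): forward mode through the gradient of the loss
    _, HJv = jax.jvp(jax.grad(loss), (y,), (Jv,))
    # J^T @ (H_L J v): reverse mode through the model
    _, model_vjp = jax.vjp(model, theta)
    return model_vjp(HJv)[0]

theta = jnp.array([0.1, -0.3])
v = jnp.array([1.0, 0.0])
print(ggn_vector_product(theta, v))
```

Note that this calls both `jax.jvp` and `jax.vjp` on `model`, which is exactly the double linearization the question is about.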
-
Yes, you can use `jax.linearize`, or implement it manually. Take these lines:

```python
_, JtV = jax.jvp(lambda th: E_jacob(x, th, rec, params), (theta_vec,), (theta_vec,))
_, e_vjp = jax.vjp(lambda th: E_jacob(x, th, rec, params), theta_vec)
Gv = e_vjp(HyJtV)[0]
```

and change them to:

```python
f_val, f_jvp = jax.linearize(lambda th: E_jacob(x, th, rec, params), theta_vec)
JtV = f_jvp(gradient)
# Transpose the linearized JVP to obtain the VJP
f_vjp = jax.linear_transpose(f_jvp, theta_vec)
Gv = f_vjp(HyJtV)[0]
```

With this, the Jacobian computed by `jax.linearize` is reused, saving computation and speeding up the code.
-
Thank you
judith valenzuela
-
You want `jax.linear_transpose` :)
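For anyone landing here, a minimal illustration of `jax.linear_transpose` on a plainly linear function (the matrix and names are illustrative):

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

def f(x):
    # a linear map x -> A @ x, R^2 -> R^3
    return A @ x

# linear_transpose needs example primals to fix input shapes/dtypes
f_T = jax.linear_transpose(f, jnp.zeros(2))

u = jnp.array([1.0, 0.0, -1.0])
print(f_T(u))  # a one-element tuple containing A.T @ u
```

The function passed to `jax.linear_transpose` must be linear in its arguments, which is why it pairs naturally with the `f_jvp` returned by `jax.linearize`.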