How to use jax.hessian? #173
This is a great question, and one that we've been discussing! One problem is that JAX gave you very unhelpful error messages for this use case. We need to improve those, and ultimately the error messages need to communicate what is and isn't supported here.

Here's the heart of the issue: given an example like this, what should the Hessian of a function with tuple/list/dict inputs even look like? For arrays, there's a clear answer because it's easy to reason about adding axes: if a function takes arrays in and returns arrays out, each level of differentiation just adds input axes to the output. But if the inputs or outputs are nested containers, it's much less obvious what structure the result should have.

I really mean that as a question! Do you have a clear sense of what would make sense from your perspective? If there's a clear way to handle cases like these, we'll implement it!

The answer may ultimately be that we can only represent Hessians for array-input, array-output functions, in which case flattening a container-input function into an array-input one, then working with the Hessian of the flattened function, may be the right answer. JAX could provide utilities for doing that (we had nice ones in Autograd, with a slick implementation). In the meantime, flattening things yourself probably makes the most sense, unless you want to wait a few days for JAX to gain some flattening utilities.

An alternative might be to use a Lanczos iteration together with a Hessian-vector product, which you can express easily in JAX. Then you'd only have to deal with vectors, rather than having to worry about how to represent matrices, and we know how to handle tuples/lists/dicts there. (But Lanczos would only be accurate for extremal eigenvalues, and its numerical effectiveness would depend on the conditioning of the Hessian, whereas direct eigenvalue algorithms would be independent of the conditioning.)
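For reference, the Hessian-vector product mentioned above can be expressed in a couple of lines with JAX's `grad` and `jvp`. A minimal sketch, where `f`, `x`, and `v` are placeholder names for a scalar-valued function of a single array argument, an evaluation point, and a direction:

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: differentiate grad(f) along the direction v,
    # so the full Hessian is never materialized.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# Example with a scalar-valued function of an array input.
f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.arange(3.0)
v = jnp.ones_like(x)
print(hvp(f, x, v))  # the Hessian of f at x, applied to v
```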
I'm exploring JAX as a more straightforward way to do higher-order derivatives. Right now I am using TensorFlow.
Flattening is currently the most straightforward way to do this in PyTorch (which also has nice flattening utilities), say when constructing Fisher information matrices in RL. Unfortunately, the same flattening approach doesn't work in TensorFlow graph mode. I think the most useful thing would be proper support for Hessians of arbitrary matrices and compositions on matrices.
I really want to compute the eigenspectrum (density of eigenvalues) and associated eigendirections, so accuracy is important.
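One way to get an exact spectrum for a pytree-input loss is the flattening route discussed above: flatten the parameters into a single vector and take the dense Hessian of the flattened function. A sketch, assuming a recent JAX that provides `jax.flatten_util.ravel_pytree` (here `loss` and `params` are placeholders for your own scalar loss and parameter pytree):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree  # available in recent JAX releases

def dense_hessian_spectrum(loss, params, *args):
    # Flatten the parameter pytree into a single vector and remember how
    # to rebuild it, turning the loss into an array-input scalar function.
    flat_params, unravel = ravel_pytree(params)
    flat_loss = lambda p: loss(unravel(p), *args)
    # Dense Hessian of the flattened loss, followed by a direct (exact)
    # symmetric eigendecomposition.
    H = jax.hessian(flat_loss)(flat_params)
    eigvals, eigvecs = jnp.linalg.eigh(H)
    return eigvals, eigvecs
```

Because this uses a direct symmetric eigensolver, its accuracy doesn't depend on the conditioning of the Hessian the way Lanczos does, but it only scales to modest parameter counts since the dense Hessian is materialized.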
We had some conversations about this, and I think we have a plan for how to handle it.
I'm glad #201 didn't close this issue, because it seems to have broken this use case for me.
I'm using this snippet from the README, combined with the "getting started with pytorch data loaders" colab. How do I compute and use the Hessian of this neural network? So far I have tried:
- `hessian_loss = hessian(loss)(params, x, y)` fails with `TypeError: jacfwd() takes 2 positional arguments but 4 were given`
- `hessian_loss = hessian(loss)((params, x, y))` takes a long time with a small network and then returns `AttributeError: 'PyLeaf' object has no attribute 'node_type'`
- `hessian_loss = hessian(lambda params_: loss(params_, x, y))(params)` fails with `AttributeError: 'PyLeaf' object has no attribute 'node_type'`

All I really need is to compute the eigenvalues and eigenvectors of this Hessian.
Note: `repr(params) = [(DeviceArray{float32[10,784]}, DeviceArray{float32[10]})]`
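For what it's worth, with `params` shaped like the repr above (a single `(W, b)` pair with `W` of shape `(10, 784)` and `b` of shape `(10,)`), one workaround is to flatten the parameters by hand and close over the data, so `hessian` only ever sees a flat array. A sketch with hypothetical stand-in data and loss (the real `loss`, `params`, `x`, `y` come from the README snippet and the data-loader colab):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins matching the repr above: one dense layer with
# W of shape (10, 784) and b of shape (10).
key = jax.random.PRNGKey(0)
W0 = 0.01 * jax.random.normal(key, (10, 784))
b0 = jnp.zeros(10)
params = [(W0, b0)]
x = jnp.ones((5, 784))   # fake batch of inputs
y = jnp.eye(10)[:5]      # fake one-hot targets

def loss(params, x, y):
    W, b = params[0]
    logits = x @ W.T + b
    return jnp.mean((logits - y) ** 2)

# Flatten [(W, b)] into one vector by hand, close over x and y, and
# rebuild the pytree inside the function that hessian differentiates.
flat0 = jnp.concatenate([W0.ravel(), b0.ravel()])

def flat_loss(flat):
    W = flat[:W0.size].reshape(W0.shape)
    b = flat[W0.size:]
    return loss([(W, b)], x, y)

# The dense Hessian is (7850, 7850) here, so this is only practical for
# small parameter counts; eigh then gives exact eigenvalues/eigenvectors.
H = jax.hessian(flat_loss)(flat0)
eigvals, eigvecs = jnp.linalg.eigh(H)
```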
I have some ideas of my own as well.