
How to use jax.hessian? #173

Closed
zafarali opened this issue Dec 29, 2018 · 4 comments
Comments

@zafarali
Contributor

I'm using this snippet from the README

from jax import jit, jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

combined with the "getting started with pytorch data loaders" colab. How do I compute and use the hessian of this neural network? So far I have tried:

  1. Naive: hessian_loss = hessian(loss)(params, x, y) (TypeError: jacfwd() takes 2 positional arguments but 4 were given)
  2. Naive2: hessian_loss = hessian(loss)((params, x, y)) (takes a long time with a small network and then returns AttributeError: 'PyLeaf' object has no attribute 'node_type')
  3. FromTheTests: hessian_loss = hessian(lambda params_: loss(params_, x, y))(params) (AttributeError: 'PyLeaf' object has no attribute 'node_type')

All I really need is to compute the eigenvalues and eigenvectors of this Hessian.

Note: repr(params) = [(DeviceArray{float32[10,784]}, DeviceArray{float32[10]})]

Some ideas that I have:

  1. Should I re-write the code to use only flat arrays?
@mattjj
Collaborator

mattjj commented Dec 29, 2018

This is a great question, and one that we've been discussing!

One problem is that JAX gave you very unhelpful error messages for this use case. We need to improve those. Ultimately, the error messages need to communicate that the hessian function, like jacfwd and jacrev, only applies to array-input array-output functions (of one argument). In particular, these functions don't work with tuple/list/dict inputs, or with respect to multiple arguments, for the same reason in both cases.

Here's the heart of the issue: given the example params here (a one-element list of a pair of arrays), how would you want the value of hessian(lambda params: loss(params, x, y))(params) to be stored? More generally, given a tuple/list/dict argument, how should we represent the Hessian?

For arrays, there's a clear answer because it's easy to reason about adding axes. If a function fun takes arrays of shape (in_1, in_2, ..., in_n) to arrays of shape (out_1, out_2, ..., out_m), then it's reasonable for hessian(fun) to be a function that takes an array of shape (in_1, in_2, ..., in_n) to an array of shape (out_1, out_2, ..., out_m, in_1, in_2, ..., in_n, in_1, in_2, ..., in_n), though other conventions could be reasonable too. (As I wrote this, I got a sense of deja vu...)
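As a toy illustration of that array convention (just a sketch to check shapes, not anything from the issue itself):

import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

# scalar-output f: R^3 -> R, so the Hessian has shape (3, 3)
f = lambda x: jnp.sum(x ** 3)
print(hessian(f)(jnp.arange(3.0)).shape)        # (3, 3)

# vector-output g: R^3 -> R^2, so the Hessian has shape (2, 3, 3)
g = lambda x: jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])
print(hessian(g)(jnp.arange(1.0, 4.0)).shape)   # (2, 3, 3)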

But if fun is, say, a function that takes a tuple of scalars to a scalar, then what should hessian(fun) return? Some kind of nested tuple structure? How do we organize the nesting?

I really mean that as a question! Do you have a clear sense of what would make sense from your perspective? If there's a clear way to handle cases like these, we'll implement it!

The answer may ultimately be that we can only represent Hessians for array-input array-output functions, in which case flattening a container-input function into an array-input one, then working with the Hessian of the flattened function, may be the right answer. JAX could provide utilities for doing that (we had nice ones in Autograd, with a slick implementation).

In the meantime, flattening things yourself probably makes the most sense, unless you want to wait a few days for JAX to gain some flattening utilities. An alternative might be to use a Lanczos iteration together with a Hessian-vector product, which you can express easily in JAX. Then you'd only have to deal with vectors, rather than having to worry about how to represent matrices, and we know how to handle tuples/lists/dicts there. (But Lanczos would only be accurate for extremal eigenvalues, and its numerical effectiveness would depend on the conditioning of the Hessian, whereas direct eigenvalue algorithms would be independent of the conditioning.)
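For reference, here's a sketch of one way to write such a Hessian-vector product (forward-over-reverse), which also works when the input is a tuple/list/dict rather than a single array:

import jax.numpy as jnp
from jax import grad, jvp

# forward-over-reverse Hessian-vector product; never materializes the full Hessian
def hvp(f, x, v):
  return jvp(grad(f), (x,), (v,))[1]

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.arange(4.0)
v = jnp.ones(4)
print(hvp(f, x, v))   # same as hessian(f)(x) @ v, without forming the matrix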

@mattjj mattjj self-assigned this Dec 29, 2018
@zafarali
Contributor Author

zafarali commented Dec 30, 2018

> I really mean that as a question! Do you have a clear sense of what would make sense from your perspective? If there's a clear way to handle cases like these, we'll implement it!

I'm exploring JAX as a more straightforward way to do higher order derivatives. Right now I am using tensorflow (tf.hessians) but it quickly becomes clunky and it doesn't work in eager mode. My real use case is to do some analysis on the eigenvalues and eigenvectors of a neural network with a few thousand parameters.

> The answer may ultimately be that we can only represent Hessians for array-input array-output functions, in which case flattening a container-input function into an array-input one, then working with the Hessian of the flattened function, may be the right answer. JAX could provide utilities for doing that (we had nice ones in Autograd, with a slick implementation).

Flattening is currently the most straightforward way to do this in pytorch (which also has nice flattening utilities), say, when constructing Fisher information matrices in RL. Unfortunately, flattening in tensorflow graph mode cannot be used with tf.hessians unless the flattened version was used to predict the output of the network.

I think the most useful way to handle Hessians of arbitrary weight matrices, or of compositions of them, would be:

  1. Have weight matrices represented as [Input_i x Output_i] for i the index of the layer.
  2. Get the output of loss function by composing the weight matrices and required non-linearities.
  3. Flatten the matrices into one long array and (somehow) use jax.hessian on the computed loss with respect to that flat array, to get a matrix of size [sum(prod(Input_i, Output_i)) x sum(prod(Input_i, Output_i))] (roughly as in the sketch after this list).
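Roughly the kind of thing I have in mind, as a toy sketch: I'm assuming a ravel_pytree-style flattening helper here, and shrinking the sizes so the dense Hessian stays small.

import jax.numpy as jnp
from jax import jit, jacfwd, jacrev
from jax.flatten_util import ravel_pytree   # assumed flattening helper

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

# toy stand-ins for the params/loss above, same nesting as repr(params)
params = [(jnp.ones((3, 4)), jnp.zeros(3))]   # [(W, b)]
x = jnp.ones((5, 4))
y = jnp.ones((5, 3))

def loss(params, x, y):
  (w, b), = params
  preds = jnp.dot(x, w.T) + b
  return jnp.mean((preds - y) ** 2)

# 1. flatten the container of arrays into one long vector (plus the inverse map)
flat_params, unravel = ravel_pytree(params)

# 2. wrap the loss so it takes that flat vector as its single array argument
flat_loss = lambda p: loss(unravel(p), x, y)

# 3. dense [P x P] Hessian, then its eigendecomposition
H = hessian(flat_loss)(flat_params)           # here P = 3*4 + 3 = 15
eigvals, eigvecs = jnp.linalg.eigh(H)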

> An alternative might be to use a Lanczos iteration together with a Hessian-vector product, which you can express easily in JAX. Then you'd only have to deal with vectors, rather than having to worry about how to represent matrices, and we know how to handle tuples/lists/dicts there. (But Lanczos would only be accurate for extremal eigenvalues, and its numerical effectiveness would depend on the conditioning of the Hessian, whereas direct eigenvalue algorithms would be independent of the conditioning.)

I really want to compute the eigenspectrum (density of eigenvalues) and the associated eigendirections, so accuracy is important.

@mattjj mattjj added enhancement New feature or request question Questions for the JAX team labels Jan 3, 2019
@mattjj
Collaborator

mattjj commented Jan 3, 2019

We had some conversations about this, and I think our plan is to:

  1. add flattening utilities (track it in Flattening function like in Autograd #190)
  2. generalize hessian to work on containers (we chose a representation to go with).

@mattjj
Collaborator

mattjj commented Jan 7, 2019

I'm glad #201 didn't close this issue, because it seems to have broken hessian! I'm looking at it now.
