Flattening function like in Autograd #190

Closed
sscardapane opened this issue Jan 3, 2019 · 5 comments · Fixed by #201
Labels
enhancement New feature or request

Comments

@sscardapane

I was reimplementing my old Autograd code in JAX, but I can't seem to find an equivalent of the "flatten" utility:
https://github.com/HIPS/autograd/tree/master/autograd/misc

Basically, given a list of the network's parameters w, I would need something like:

w_flat, unflattener = flatten(w)

where w_flat is a one-dimensional vector containing all parameters, and unflattener reverts this operation. If I understand correctly, this was used in Autograd at the beginning for the implementation of the optimizers, but it is not needed here. However, it is useful in many cases and also goes well with the idea of chainable transformations on data.
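To make the request concrete, here is a minimal sketch of such a flatten built on `jax.tree_util`. It assumes the parameters are an arbitrary nested list/tuple/dict of arrays (a "pytree") containing at least one leaf. The name `flatten` is hypothetical here and is not a JAX API; it only mirrors the Autograd utility being asked about.

```python
import jax.numpy as jnp
from jax import tree_util


def flatten(params):
    # Hypothetical helper (not a JAX API): pytree -> (1-D vector, unflattener).
    # Assumes params contains at least one array leaf.
    leaves, treedef = tree_util.tree_flatten(params)
    leaves = [jnp.asarray(leaf) for leaf in leaves]
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.size for leaf in leaves]
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def unflatten(vec):
        # Slice the flat vector back into pieces and restore each leaf's shape.
        pieces, start = [], 0
        for shape, size in zip(shapes, sizes):
            pieces.append(jnp.reshape(vec[start:start + size], shape))
            start += size
        return tree_util.tree_unflatten(treedef, pieces)

    return flat, unflatten


w = [jnp.ones((2, 3)), jnp.zeros(4)]
w_flat, unflattener = flatten(w)
print(w_flat.shape)  # (10,)
```

The unflattener closes over the tree structure and leaf shapes recorded at flatten time, so it only makes sense for vectors of the same length and layout.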

@mattjj mattjj self-assigned this Jan 3, 2019
@mattjj mattjj added the enhancement New feature or request label Jan 3, 2019
@mattjj
Collaborator

mattjj commented Jan 6, 2019

I added some simple utilities in #201, so you should be able to write

from jax.flatten_util import ravel_pytree

w_flat, unflattener = ravel_pytree(w)

I changed the name from "flatten" to "ravel_pytree" because I think it's more descriptive and in-line with other names in JAX. At one point we had too many uses of "flatten" and "unflatten" to mean slightly different things.

Please re-open this issue if you have issues with that function, or if I missed the mark somehow.
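For reference, a small usage sketch of `ravel_pytree`, assuming a dict of parameters for a hypothetical linear model (`loss`, `W`, `b` and `x` are illustrative names, not from the thread). It also illustrates the "chainable transformations" point from the original request: the unflattener is an ordinary JAX-traceable function, so a loss written against structured parameters can be differentiated with respect to the flat vector.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters for a tiny linear model: a dict of arrays (a pytree).
w = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2)}

# w_flat is a 1-D array holding all 8 parameters; unflattener inverts the ravel.
w_flat, unflattener = ravel_pytree(w)
print(w_flat.shape)  # (8,)


def loss(params, x):
    # Hypothetical squared-output loss, written against the structured params.
    return jnp.sum((x @ params["W"] + params["b"]) ** 2)


# Differentiate with respect to the flat vector by unflattening inside the loss.
x = jnp.ones((4, 3))
flat_loss = lambda v: loss(unflattener(v), x)
g_flat = jax.grad(flat_loss)(w_flat)
print(g_flat.shape)  # (8,)
```

The flat gradient should match the structured gradient after ravelling it with the same function, since both share the same tree structure.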

@neonwatty

Is there a flatten_func analog?

@neonwatty

@mattjj, I'm probably just missing it: is a flatten_func analog baked into JAX somewhere, perhaps in the flatten_util or tree_util files?

@mattjj
Collaborator

mattjj commented Mar 26, 2019

Thanks for the ping. My notification setup isn't great. When in doubt, open a new issue, since those are much more visible. It also makes things like feature requests easier to track.

We haven't added a flatten_func to JAX, but it's probably a good idea to put it in flatten_util.py and make it call ravel_pytree. I think it should look about the same as Autograd's. We could call it ravel_func or flatten_func; I don't have a strong opinion.

Want to send us a PR with that change?
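A sketch of what the suggested function could look like, assuming it follows Autograd's `flatten_func` in ravelling the example input once and having the wrapped function take a flat vector and return a ravelled output. The name `ravel_func` is just one of the two names floated above; it is hypothetical and not an existing JAX API, and `predict` is an illustrative function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ravel_func(func, example):
    """Hypothetical analogue of Autograd's flatten_func, built on ravel_pytree.

    Returns (flat_func, unravel, flat_example), where flat_func takes a 1-D
    vector in place of the pytree argument and returns a ravelled output.
    """
    flat_example, unravel = ravel_pytree(example)

    def flat_func(flat_params, *args):
        # Rebuild the pytree, call the original function, and ravel the result.
        out = func(unravel(flat_params), *args)
        return ravel_pytree(out)[0]

    return flat_func, unravel, flat_example


# Usage with a hypothetical function of structured parameters.
def predict(params, x):
    return jnp.tanh(x @ params["W"] + params["b"])


params = {"W": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.ones((4, 3))
flat_predict, unravel, flat_params = ravel_func(predict, params)
print(flat_predict(flat_params, x).shape)  # (8,)
```

Because both input and output are flat vectors, transformations such as `jax.jacobian(flat_predict)` yield a plain 2-D matrix rather than a nested structure.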

@neonwatty

Great! - will do.


3 participants