
Is it possible to use objects with Jax's machine learning library #1967

Closed
Montana opened this issue Jan 9, 2020 · 7 comments
Labels
question Questions for the JAX team

Comments

@Montana

Montana commented Jan 9, 2020

Using classes works for me as long as I don't refer to any class objects inside my functions, but I was wondering: is it possible to use objects with JAX?

-Montana

@jbaron

jbaron commented Jan 9, 2020

You can use objects, but you also need to register functions that tell JAX how to flatten and unflatten them when they are passed to a traceable function. This is, however, a straightforward process (some examples can be found here: https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html).

Alternatively I sometimes use named tuples if I just want to group some parameters in order to increase readability of my code.
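A minimal sketch of the named-tuple approach, assuming a hypothetical `Params` container (the field names are illustrative). JAX already treats `NamedTuple` subclasses as pytrees, so nothing needs to be registered, and `jax.grad` hands back gradients in the same container type:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # Hypothetical parameter container; field names are illustrative only.
    w: jnp.ndarray
    b: jnp.ndarray


@jax.jit
def predict(params, x):
    # Attribute access works inside jit because NamedTuples are pytrees
    # out of the box (no registration step required).
    return params.w * x + params.b


params = Params(w=jnp.array(2.0), b=jnp.array(1.0))
y = predict(params, 3.0)  # 2.0 * 3.0 + 1.0

# Gradients come back as a Params instance with the same field structure.
grads = jax.grad(lambda p: predict(p, 3.0))(params)
```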

@Montana
Author

Montana commented Jan 9, 2020

Thanks @jbaron, that makes sense, since I'd expect NamedTuples to work like other Python builtins. I appreciate your response.

-Montana

@umangjpatel

Can you send me an example of how to do that? For instance, creating a class and extending it to work with JAX?

@mattjj
Collaborator

mattjj commented Jan 12, 2020

The technique @jbaron suggests works if you want your classes to be treated as containers, isomorphic to tuples (i.e. product types). The functions you register define the isomorphism: one function takes an instance and flattens it to an iterable (plus some metadata, like dict keys), and another takes the iterable (plus the metadata) and reproduces a class instance equal to the original one.

Doing that will let instances of your own container-like classes be passed as arguments and returned as values to and from JAX-transformed functions (like jitted functions). After all, JAX will effectively see them as flat tuples, using the functions you provided to do the conversion back and forth.
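To make the flatten/unflatten round trip concrete, here is a small sketch using `jax.tree_util.tree_flatten` and `tree_unflatten` directly, outside of any transformation. The `Point` class and the two helper functions are hypothetical names chosen for illustration:

```python
from jax import tree_util


class Point:
    # Hypothetical container class used only for illustration.
    def __init__(self, x, y):
        self.x = x
        self.y = y


def flatten_point(p):
    # Return (children, aux_data): the children are the leaves JAX will
    # trace; there is no static metadata for this class, so aux_data is None.
    return (p.x, p.y), None


def unflatten_point(aux_data, children):
    # Build a *new* instance from the metadata and the flat children.
    x, y = children
    return Point(x, y)


tree_util.register_pytree_node(Point, flatten_point, unflatten_point)

p = Point(1.0, 2.0)
leaves, treedef = tree_util.tree_flatten(p)       # leaves: [1.0, 2.0]
rebuilt = tree_util.tree_unflatten(treedef, leaves)

# Other tree utilities now understand Point too, e.g. mapping over leaves.
doubled = tree_util.tree_map(lambda v: v * 2, p)
```

Note that `rebuilt` is a different object from `p` with equal contents, which is exactly the behaviour JAX transformations rely on.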

```python
from jax import jit

class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

@jit
def f(special):
  return special.x + special.y

special = Special(1, 2)

f(special)  # TypeError: Argument '<__main__.Special object at 0x7f1403ee5e10>' of type <class '__main__.Special'> is not a valid JAX type

# Register Special as a pytree: the first function flattens an instance into
# (children, aux_data); the second rebuilds an instance from (aux_data, children).
from jax import tree_util
tree_util.register_pytree_node(Special, lambda s: ((s.x, s.y), None), lambda _, xs: Special(xs[0], xs[1]))

f(special)  # 3
```
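`jax.tree_util` also offers a decorator form, `register_pytree_node_class`, which looks for a `tree_flatten` method and a `tree_unflatten` classmethod on the class. A sketch of the same `Special` example written that way (an alternative spelling of the same registration, not a different mechanism):

```python
import jax
from jax import tree_util


@tree_util.register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # (children, aux_data): x and y are traced; no static metadata.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Recreate an instance from the flat children.
        return cls(*children)


@jax.jit
def f(special):
    return special.x + special.y


result = f(Special(1, 2))
```

Keeping the flatten/unflatten logic on the class itself can make it harder to forget to update when fields are added.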

Object identity is not preserved (which is a condition of functional purity):

```python
@jit
def g(s1, s2):
  return s1 is s2

g(special, special)  # False
```

You can use methods on the instances you pass in, but side-effects won't work (they will likely silently fail rather than raising an error).

As @jbaron said, these things get flattened and then unflattened/re-created when being passed into or out of JAX-transformed functions, so if you keep that in mind it might be easier to remember the constraints (no object identity, no side effects).
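A minimal sketch of how a side effect silently fails under `jit`, assuming a hypothetical `Counter` class whose `increment` method mutates the instance. Inside the jitted function the method runs on a fresh instance rebuilt from traced leaves, so the caller's object never changes and no error is raised:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Counter:
    # Hypothetical class with an impure (mutating) method.
    def __init__(self, count):
        self.count = count

    def increment(self):
        # Side effect: mutates the instance in place.
        self.count = self.count + 1
        return self.count


tree_util.register_pytree_node(
    Counter,
    lambda c: ((c.count,), None),
    lambda _, children: Counter(children[0]),
)


@jax.jit
def step(counter):
    # `counter` here is a copy rebuilt from traced leaves, so the
    # mutation inside increment() only affects that copy.
    return counter.increment()


c = Counter(jnp.array(0))
out = step(c)  # returns 1, but c.count is still 0
```

A pure alternative would have `increment` return a new `Counter` and have the caller use the returned value.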

Since I think @Montana's original question was answered, and hopefully @umangjpatel's request for an example was as well, I'll close this issue. Please open others if more questions come up!

@mattjj closed this as completed Jan 12, 2020
@mattjj self-assigned this Jan 12, 2020
@mattjj added the question (Questions for the JAX team) label Jan 12, 2020
@Montana
Author

Montana commented Feb 1, 2020

This answered my question, thank you!

@samskiter

I just wanted to highlight that I don't think this answer is entirely true: you can't just apply this method and expect grad to work in every case.
See this discussion, and my specific reproduction:
#17341

@jakevdp
Collaborator

jakevdp commented Aug 30, 2023

> I just wanted to highlight, that I don't think that this answer is entirely true - you can't just apply this method and have grad work totally.

To make this more explicit: objects which have impure methods (i.e. those with side-effects – see JAX Sharp Bits: Pure Functions) will not work correctly with jax.grad and other transformations, even if they are registered as pytrees.
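For contrast, a sketch of the case that does work: a registered class whose method is pure (it only reads attributes and mutates nothing) can be differentiated with `jax.grad`, and the gradient comes back as an instance of the same class. The `Linear` model and `loss` function are hypothetical names for illustration:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Linear:
    # Hypothetical model container; w and b are differentiable leaves.
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        # Pure method: reads attributes, has no side effects.
        return self.w * x + self.b


tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.w, m.b), None),
    lambda _, children: Linear(*children),
)


def loss(model, x, target):
    return (model(x) - target) ** 2


model = Linear(jnp.array(2.0), jnp.array(1.0))
# Differentiate with respect to the first argument (the model).
grads = jax.grad(loss)(model, 3.0, 10.0)
```

Here the prediction is 7.0 and the residual is -3.0, so `grads.w` is 2 * (-3.0) * 3.0 = -18.0 and `grads.b` is 2 * (-3.0) = -6.0.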


6 participants