Is it possible to use objects with Jax's machine learning library #1967
You can use objects, but you also need to register some code telling JAX how to flatten and unflatten them when passing them to a traceable function. This is, however, a straightforward process (some examples can be found here: https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html). Alternatively, I sometimes use named tuples if I just want to group some parameters to make my code more readable.
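A minimal sketch of the named-tuple approach, assuming a recent JAX release in which `typing.NamedTuple` subclasses are treated as pytrees automatically, so no registration is needed. The `Params` class and `predict` function are hypothetical names used only for illustration:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # A NamedTuple is already a pytree in JAX; no flatten/unflatten code required.
    w: jnp.ndarray
    b: jnp.ndarray


@jax.jit
def predict(params, x):
    # Fields are accessed by name, which keeps model code readable.
    return params.w * x + params.b


params = Params(w=jnp.array(2.0), b=jnp.array(1.0))
y = predict(params, 3.0)  # 2.0 * 3.0 + 1.0 = 7.0

# jax.grad with respect to a NamedTuple returns gradients in the same structure.
grads = jax.grad(lambda p: predict(p, 3.0))(params)
# grads.w is 3.0 (d/dw of w*x + b at x = 3) and grads.b is 1.0
```

Because the gradient comes back as a `Params`, it can be combined field by field with the original parameters in an update step.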
Thanks @jbaron, that makes sense, since I'd expect NamedTuples to work like other Python builtins. I appreciate your response. -Montana
Could you send me an example of how to do that? For instance, creating a class and extending it to work with JAX?
The technique @jbaron suggests works if you want your classes to be treated as containers, isomorphic to tuples (i.e. product types). The functions you register define the isomorphism: one function takes an instance and flattens it to an iterable (plus some metadata, like dict keys), and another takes the iterable (plus the metadata) and reproduces a class instance equal to the original one. Doing that lets instances of your own container-like classes be passed as arguments to, and returned as values from, JAX-transformed functions (like jitted functions). After all, JAX will effectively see them as flat tuples, using the functions you provided to do the conversion back and forth.

from jax import jit
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y
@jit
def f(special):
    return special.x + special.y
special = Special(1, 2)
f(special) # TypeError: Argument '<__main__.Special object at 0x7f1403ee5e10>' of type <class '__main__.Special'> is not a valid JAX type
from jax import tree_util
tree_util.register_pytree_node(
    Special,
    lambda s: ((s.x, s.y), None),         # flatten: children (x, y), no aux data
    lambda _, xs: Special(xs[0], xs[1]),  # unflatten: rebuild from the children
)
f(special)  # 3

Object identity is not preserved (which is a condition of functional purity):

@jit
def g(s1, s2):
    return s1 is s2
g(special, special)  # False

You can use methods on the instances you pass in, but side effects won't work (they will likely fail silently rather than raise an error). As @jbaron said, these objects get flattened and then unflattened/re-created when being passed into or out of JAX-transformed functions, so keeping that in mind may make the constraints easier to remember (no object identity, no side effects).

Since I think @Montana's original question was answered, and hopefully @umangjpatel's request for an example was as well, I'll close this issue. Please open others if more questions come up!
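To illustrate the "no side effects" point, here is a hedged sketch using `jax.tree_util.register_pytree_node_class`, a decorator form of registration in which the class supplies its own `tree_flatten` method and `tree_unflatten` classmethod. The `Counter` class and the `step`/`pure_step` functions are hypothetical. A method that mutates `self` inside a jitted function only mutates the temporary instance JAX rebuilds from the flattened leaves, so the caller's object is left unchanged; returning a new instance instead works as expected:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Counter:
    def __init__(self, count):
        self.count = count

    def tree_flatten(self):
        # Children (leaves JAX traces) plus static auxiliary data (none here).
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Build a fresh instance from the leaves; this is a new object.
        return cls(*children)

    def increment(self):
        # Impure: mutates self instead of returning a new Counter.
        self.count = self.count + 1

    def incremented(self):
        # Pure: returns a new Counter and leaves self unchanged.
        return Counter(self.count + 1)


@jax.jit
def step(c):
    # `c` was rebuilt from tracers, so this mutation never reaches the caller.
    c.increment()
    return c.count


@jax.jit
def pure_step(c):
    # The returned Counter is flattened on the way out and rebuilt for the caller.
    return c.incremented()


c = Counter(jnp.array(0))
out = step(c)        # out is 1, but c.count is still 0
new_c = pure_step(c)  # new_c.count is 1; c is untouched
```

The pure version is the pattern that composes with `jit` and `grad`: state changes are expressed as new return values rather than mutations.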
This answered my question, thank you!
I just wanted to point out that I don't think this answer is entirely true: you can't just apply this method and expect grad to work in every case.
To make this more explicit: objects that have impure methods (i.e. those with side effects; see JAX Sharp Bits: Pure Functions) will not work correctly with
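As a hedged illustration of the pure side of this distinction: `jax.grad` can differentiate with respect to a registered class as long as the methods it calls are pure, and the gradient comes back as an instance of the same class. The `Linear` class and `loss` function below are hypothetical, and this sketch does not exercise the impure case the comment warns about:

```python
import jax
from jax import tree_util


@tree_util.register_pytree_node_class
class Linear:
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        # Pure method: the output depends only on self's leaves and x; no mutation.
        return self.w * x + self.b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # Squared error of a single prediction.
    return (model(x) - y) ** 2


model = Linear(2.0, 1.0)
grads = jax.grad(loss)(model, 3.0, 10.0)
# grads is itself a Linear holding d(loss)/dw and d(loss)/db:
# residual = (2*3 + 1) - 10 = -3, so dw = 2*(-3)*3 = -18 and db = 2*(-3) = -6
```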
Using classes works for me as long as I don't refer to any class objects inside my functions, but I was wondering whether using objects is possible with JAX.
-Montana