pytrees and how we use them. #48

Open
Jordan-Dennis opened this issue Sep 2, 2023 · 2 comments

Comments

@Jordan-Dennis
Collaborator

equinox has fallen out of favour for being too rigid. Instead, we are discussing using pytreeclass as the base model for ticktack. I think that ultimately @benjaminpope would prefer to build this on top of zodiax, which is a good solution, but I thought I would make a token effort to look at alternatives. My favourite of these, and one that I will probably look into anyway, is to register the pytree classes manually. The JAX API for doing this is very simple, and if we do decide to remove the Box and Flow classes then we would only need to register a handful of classes, so we would not be adding much boilerplate to the code.
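For reference, the manual registration route mentioned above can be sketched as follows. This is only an illustration using `jax.tree_util.register_pytree_node`; the `Reservoir` class, its fields, and the helper names are hypothetical and are not taken from ticktack.

```python
import jax
import jax.numpy as jnp


class Reservoir:
    """Hypothetical container used for illustration; not a ticktack class."""

    def __init__(self, contents, name):
        self.contents = contents  # array leaf, visible to JAX transformations
        self.name = name          # static metadata, carried as aux_data


def flatten_reservoir(r):
    # Return (children, aux_data): an iterable of leaves, then hashable metadata.
    return (r.contents,), r.name


def unflatten_reservoir(name, children):
    # Inverse of flatten_reservoir: rebuild the object from aux_data and leaves.
    (contents,) = children
    return Reservoir(contents, name)


# One call per class is all the boilerplate that registration requires.
jax.tree_util.register_pytree_node(Reservoir, flatten_reservoir, unflatten_reservoir)

r = Reservoir(jnp.arange(3.0), "troposphere")

# Once registered, the class works with the generic tree utilities.
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, r)
leaves = jax.tree_util.tree_leaves(r)
```

The only design decision per class is which attributes are leaves (traced arrays) and which are auxiliary data (static, hashable values), which is why the cost stays small when only a handful of classes need registering.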

Ultimately, I think that we will use zodiax, but this does raise some further, mostly philosophical, questions that I think we should consider before taking the next steps. Since that discussion is veering towards design philosophy rather than functionality, I will create a separate discussion for it and leave this issue as a way to track the different interfaces to pytrees and the progress that I make.

@Jordan-Dennis
Collaborator Author

Hi all, I have started the experiment. Based on my first attempt, I expect that we will encounter issues with side effects. I wasted a bunch of time in my session today writing some custom syntax formatting files for my editor to get better type hint highlighting, but I plan to pick this back up next weekend. Hopefully we can catch up some time during the week though.

@Jordan-Dennis
Collaborator Author

So, the experiment was a raging success. Here is an MWE for constructing a custom pytree node.

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class A:
  vector: jax.Array

  def __init__(self, vector: jax.Array):
    self.vector = vector

  def tree_flatten(self):
    # Return (children, aux_data); children must be an iterable of leaves.
    return (self.vector,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild the node from the leaves produced by tree_flatten.
    return cls(*children)

@jax.jit
def some_function(a: A) -> A:
  # This assignment acts on the copy that jit rebuilds for tracing.
  a.vector += 1.0
  return a

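This function runs correctly. However, it is important to notice that the object that is returned is not the same object that you passed into the function: jax.jit flattens the argument into its leaves, traces the function on a rebuilt copy, and constructs a new A from the outputs, so the caller's instance is not modified. This is the same as what happened with equinox.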
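To make the identity point concrete, here is a minimal self-contained sketch. It assumes a class like `A` above whose `tree_flatten` returns a tuple of children; `add_one` is a hypothetical name, and it builds a new node explicitly instead of assigning to an attribute, which is one way to make the functional behaviour visible in the code.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class A:
    def __init__(self, vector):
        self.vector = vector

    def tree_flatten(self):
        # (children, aux_data): one array leaf, no static metadata
        return (self.vector,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def add_one(a: A) -> A:
    # Return a fresh node rather than mutating the argument.
    return A(a.vector + 1.0)


x = A(jnp.zeros(3))
y = add_one(x)

print(y is x)    # False: jit rebuilds the output through tree_unflatten
print(x.vector)  # the input still holds zeros
print(y.vector)  # the output holds ones
```

Writing the update as "return a new object" does not change what `jax.jit` does, but it avoids suggesting that the caller's instance is updated in place.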

What are people's thoughts?
