Haiku integration #5
-
Hi all, very nice package. I'm just starting to port my model to it. I was wondering what you'd recommend as the best way to integrate a flow into a larger Haiku model. To give a bit of context: I am using a custom encoder to encode a high-dimensional dataset, and then using a flow conditioned on the encoded vector to model a distribution. I noticed the trainer class seems like it might be able to handle general loss functions. Would you recommend using that for my larger Haiku model (and passing the loss_fn)? Or instead, would it be better if I use the … Thanks!!
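The following is a minimal pure-JAX sketch of the setup described in the question. It uses no Haiku or NuX, and every name in it is hypothetical. A toy dense encoder produces a context vector, and a single conditional affine flow models y given that context. One loss_fn over a shared params dict lets gradients reach both the encoder and the flow. A real conditional flow would stack several coupling layers; the single affine layer is an assumption made to keep the example short.

```python
import jax
import jax.numpy as jnp

# Hypothetical dimensions and learning rate, chosen only for illustration.
DATA_DIM, CONTEXT_DIM, TARGET_DIM = 16, 4, 2
LR = 1e-2


def init_params(key):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        # Toy encoder: one dense layer standing in for a custom encoder.
        "enc_w": 0.1 * jax.random.normal(k1, (DATA_DIM, CONTEXT_DIM)),
        "enc_b": jnp.zeros(CONTEXT_DIM),
        # Conditioner: maps the context vector to a shift and a log-scale.
        "shift_w": 0.1 * jax.random.normal(k2, (CONTEXT_DIM, TARGET_DIM)),
        "log_scale_w": 0.1 * jax.random.normal(k3, (CONTEXT_DIM, TARGET_DIM)),
    }


def encode(params, x):
    return jnp.tanh(x @ params["enc_w"] + params["enc_b"])


def conditional_log_prob(params, y, context):
    # Affine flow: y = shift(c) + exp(log_scale(c)) * z, with z ~ N(0, I).
    shift = context @ params["shift_w"]
    log_scale = context @ params["log_scale_w"]
    z = (y - shift) * jnp.exp(-log_scale)
    log_base = -0.5 * jnp.sum(z**2 + jnp.log(2 * jnp.pi), axis=-1)
    # Change of variables: log p(y|c) = log N(z) - sum(log_scale).
    return log_base - jnp.sum(log_scale, axis=-1)


def loss_fn(params, x, y):
    # Negative conditional log-likelihood, averaged over the batch.
    context = encode(params, x)
    return -jnp.mean(conditional_log_prob(params, y, context))


@jax.jit
def sgd_step(params, x, y):
    # One gradient step updates encoder and flow parameters together.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


# Usage on random data.
kp, kx, ky = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(kp)
x = jax.random.normal(kx, (32, DATA_DIM))
y = jax.random.normal(ky, (32, TARGET_DIM))
params = sgd_step(params, x, y)
```

Because the loss is a plain function of one params pytree, it can be handed to whatever training loop or trainer is in use. Whether NuX's trainer accepts such a function directly is not confirmed here.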
-
Hi Miles,
Then I think it would be a bit more 'jax-y' if the …
-
Hi Miles,

Unfortunately, I'm not sure how easy it will be to use NuX in a Haiku model. NuX is built on Haiku, but it adds some extra functionality at a fairly low level that might make combining Haiku and NuX difficult. Haiku keeps track of parameters and state, and NuX also tracks compile-time constants behind the scenes. These are mainly used to save the shapes of inputs at initialization. At runtime, the flow layers can then tell how inputs are batched and handle batches automatically. I have not done any testing to see whether you can combine NuX layers and Haiku modules, but it might be possible.

That being said, the Layer class in NuX is a wrapped version of the Haiku module, so it should be straightforward to rewrite your Haiku model as a NuX Layer. Check out nux.networks for examples of this.

I also highly recommend that you use the trainer class to define a new loss function. You can do this by creating a new class that inherits from FlowTrainer and writing a new loss function. The maximum likelihood and classification trainers are examples of this. I'll write a tutorial with more details over the next couple of days. Until then, if you give the implementation a shot, I'll be happy to answer any questions you have.
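The reply above describes how recording input shapes at initialization lets a layer handle batches automatically. Below is a small JAX sketch of that general idea. It is a hypothetical illustration and not NuX's actual Layer implementation. ShapeAwareLayer, scale_layer and LOG_SCALE are made-up names. The layer function is written for one unbatched example. The wrapper stores that example's shape at init. At call time, it adds one jax.vmap per extra leading axis, so the per-example log-determinant comes out with one value per batch element.

```python
import jax
import jax.numpy as jnp

LOG_SCALE = 0.5  # hypothetical fixed log-scale for the toy layer


def scale_layer(x):
    """Toy flow layer written for ONE unbatched example."""
    z = x * jnp.exp(LOG_SCALE)
    # log|det dz/dx| sums over the event dimensions of this single example.
    log_det = jnp.sum(LOG_SCALE * jnp.ones_like(x))
    return z, log_det


class ShapeAwareLayer:
    """Hypothetical sketch of shape tracking; not NuX's actual Layer class."""

    def __init__(self, fn):
        self.fn = fn
        self.event_shape = None  # treated as a compile-time constant

    def init(self, x_example):
        # Record the shape of a single unbatched input at initialization.
        self.event_shape = tuple(x_example.shape)
        return self

    def __call__(self, x):
        if self.event_shape is None:
            raise RuntimeError("call init() before applying the layer")
        n_batch_dims = x.ndim - len(self.event_shape)
        if n_batch_dims < 0 or tuple(x.shape[n_batch_dims:]) != self.event_shape:
            raise ValueError(f"expected trailing shape {self.event_shape}, got {x.shape}")
        f = self.fn
        # Add one vmap per leading batch axis so fn always sees one example.
        for _ in range(n_batch_dims):
            f = jax.vmap(f)
        return f(x)


layer = ShapeAwareLayer(scale_layer).init(jnp.zeros(3))
z, log_det = layer(jnp.ones((4, 3)))  # z.shape == (4, 3), log_det.shape == (4,)
```

Without the recorded event shape, the layer could not tell a batch of four 3-vectors from a single 4x3 example. Summing over every axis would then silently merge the log-determinants of different examples.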
Hi Miles,
Giving my two cents as a user, this is what I use to integrate a flow from NuX into a bigger model:
Then sample_flow is a pure function mapping params, state, and rng_key to samples and log prob…
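The original snippet is not preserved here, so the following only guesses at the general shape of such a function. It is written in plain JAX without NuX. The hypothetical sample_flow(params, state, rng_key) implements a single affine flow and returns samples together with their log-densities. Because it is pure, a larger model can call it inside its own loss and differentiate through it with jax.grad. The "temperature" entry in state, the N(1, I) target and the reverse-KL objective are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

N_SAMPLES = 4  # hypothetical number of samples per call


def sample_flow(params, state, rng_key):
    """Hypothetical pure sampler for x = shift + exp(log_scale) * z."""
    temperature = state["temperature"]  # non-trainable value kept in state
    eps = jax.random.normal(rng_key, (N_SAMPLES,) + params["shift"].shape)
    z = temperature * eps  # z ~ N(0, temperature^2 I)
    x = params["shift"] + jnp.exp(params["log_scale"]) * z
    # log N(z; 0, T^2 I), summed over the event axis.
    log_pz = jnp.sum(-0.5 * eps**2 - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(temperature), axis=-1)
    # Change of variables: subtract log|det dx/dz| = sum(log_scale).
    log_prob = log_pz - jnp.sum(params["log_scale"])
    return x, log_prob


def target_log_prob(x):
    # Assumed target for the demo: N(1, I).
    return jnp.sum(-0.5 * (x - 1.0) ** 2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)


def reverse_kl(params, state, rng_key):
    # Monte Carlo estimate of KL(q || p) using reparameterized samples.
    x, log_q = sample_flow(params, state, rng_key)
    return jnp.mean(log_q - target_log_prob(x))


params = {"shift": jnp.zeros(2), "log_scale": jnp.zeros(2)}
state = {"temperature": 1.0}
# Gradients are taken w.r.t. params only; state is passed through untouched.
grads = jax.grad(reverse_kl)(params, state, jax.random.PRNGKey(0))
```

The log-density is computed from the base noise directly rather than by inverting x. This keeps sampling and density evaluation in one forward pass, which is convenient when the samples feed into a larger model.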