Training Tutorial #14
Replies: 1 comment, 7 replies
-
Thanks for the interest in NuX! The flows are built to be as flexible as possible, so in theory you can train them with any of the JAX training frameworks out there (optax, flax, haiku, etc.). You'd create and initialize the flow using the format from the README example, which is roughly:

```python
import nux

# 1. Create the flow
flow = nux.Sequential(..., prior=...)
# 2. Initialize the flow with data
z, log_px = flow(initialization_data, params=None, rng_key=initialization_key)
# 3. Retrieve the initialized parameters
params = flow.get_params()
# 4. Give your JAX framework the parameters
#    (Depends on the framework. See https://github.com/Information-Fusion-Lab-Umass/NuX/blob/master/examples/haiku_vae.py#L83 for a Haiku example)
# 5. Use the flow in your code
z, log_px = flow(x, params=params, rng_key=rng_key)
```

The FlowTrainer class is another way to train a model. You can take a look at this for an example of how to use it, but to be honest, I think training with one of the larger frameworks might be the best way to go, because they are more mature. However, if you're still interested in training flows with this library's own class and want to learn more, I'd be happy to explain it in more detail.
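As a rough illustration of steps 4 and 5, here is a minimal JAX sketch of maximum-likelihood training with plain gradient descent. It does not use NuX: `toy_flow` is a hypothetical stand-in (an elementwise affine flow with a standard normal prior) that returns `(z, log_px)` like the flow call above, and `loss_fn`, `sgd_step`, and `LEARNING_RATE` are illustrative names. With a NuX flow, the assumption is that you would replace `toy_flow(x, params)` with `flow(x, params=params, rng_key=rng_key)` and that this call is differentiable with respect to `params`; that is untested here.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, chosen for this toy problem


# Hypothetical stand-in for a NuX flow (NOT the NuX API):
# z = (x - shift) * exp(-log_scale), with a standard normal prior on z.
# Like the flow call in step 5, it returns (z, log_px).
def toy_flow(x, params):
    z = (x - params["shift"]) * jnp.exp(-params["log_scale"])
    # log N(z; 0, I), summed over the feature dimension
    log_pz = -0.5 * jnp.sum(z**2 + jnp.log(2.0 * jnp.pi), axis=-1)
    # log|det dz/dx| = -sum(log_scale) for this elementwise map
    log_det = -jnp.sum(params["log_scale"])
    return z, log_pz + log_det


# Negative mean log-likelihood: the usual maximum-likelihood objective for flows
def loss_fn(params, x):
    _, log_px = toy_flow(x, params)
    return -jnp.mean(log_px)


@jax.jit
def sgd_step(params, x):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    # Plain gradient descent applied leaf-by-leaf over the params pytree
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Synthetic 2-D data with mean (2, -1) and standard deviation (0.5, 3)
key = jax.random.PRNGKey(0)
data = jnp.array([2.0, -1.0]) + jnp.array([0.5, 3.0]) * jax.random.normal(key, (4096, 2))

# Step 3 analogue: an initial parameter pytree
params = {"shift": jnp.zeros(2), "log_scale": jnp.zeros(2)}

# Steps 4-5 analogue: repeatedly evaluate the flow and update its parameters
for step in range(1000):
    params, loss = sgd_step(params, data)
```

The update rule is deliberately the simplest one possible; an optax optimizer such as `optax.adam` (through its `init`/`update` functions and `optax.apply_updates`) could be dropped in instead. For this affine toy flow, the learned `shift` and `exp(log_scale)` should approach the sample mean and standard deviation of the data.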
-
Hello everybody!
I would like to know if there is any tutorial on how to train normalizing flows with this library. I've been looking but haven't found one, and I couldn't understand the FlowTrainer class either.
Greetings
David