Replies: 1 comment
Hey @ariG23498, you can use:

```python
def put_model(model, device):
    state = nnx.state(model)
    state = jax.device_put(state, device)
    nnx.update(model, state)
```
Hey folks!
Is there a one-stop solution for onloading and offloading a model to and from any accelerator device (GPU, TPU)?
I am working on a diffusion pipeline that has 4 models in total (2 text encoders, 1 flow model, and an autoencoder). I would like to juggle between loading and offloading the models for better memory management.
Any help would be great!