generated from alvarobartt/python-package-template
-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🔀 Merge pull request #16 from alvarobartt/readme-update
📝 Update `README.md` and `docs/usage.md`
- Loading branch information
Showing
2 changed files
with
208 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,120 @@ | ||
# 💻 Usage | ||
|
||
Let's create a `flax` model using the Linen API and initialize it. | ||
## `flax` | ||
|
||
```python | ||
import jax | ||
from flax import linen as nn | ||
from jax import numpy as jnp | ||
* Convert `params` to `bytes` in memory | ||
|
||
class SingleLayerModel(nn.Module): | ||
features: int | ||
```python | ||
from safejax import serialize, deserialize | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
x = nn.Dense(features=self.features)(x) | ||
return x | ||
params = model.init(...) | ||
|
||
model = SingleLayerModel(features=1) | ||
encoded_bytes = serialize(params) | ||
decoded_params = deserialize(encoded_bytes, freeze_dict=True) | ||
|
||
rng = jax.random.PRNGKey(0) | ||
params = model.init(rng, jnp.ones((1, 1))) | ||
``` | ||
model.apply(decoded_params, ...) | ||
``` | ||
|
||
Once done, we can already save the model params with `safejax` (using `safetensors` | ||
storage format) using `safejax.serialize`. | ||
* Convert `params` to `bytes` in `params.safetensors` file | ||
|
||
```python | ||
from safejax import serialize | ||
```python | ||
from safejax import serialize, deserialize | ||
|
||
serialized_params = serialize(params=params) | ||
``` | ||
params = model.init(...) | ||
|
||
Those params can be later loaded using `safejax.deserialize` and used | ||
to run the inference over the model using those weights. | ||
encoded_bytes = serialize(params, filename="./params.safetensors") | ||
decoded_params = deserialize("./params.safetensors", freeze_dict=True) | ||
|
||
```python | ||
from safejax import deserialize | ||
model.apply(decoded_params, ...) | ||
``` | ||
|
||
params = deserialize(path_or_buf=serialized_params, freeze_dict=True) | ||
``` | ||
--- | ||
|
||
And, finally, running the inference as: | ||
## `dm-haiku` | ||
|
||
```python | ||
x = jnp.ones((1, 28, 28, 1)) | ||
y = model.apply(params, x) | ||
``` | ||
* Just contains `params` | ||
|
||
```python | ||
from safejax import serialize, deserialize | ||
|
||
params = model.init(...) | ||
|
||
encoded_bytes = serialize(params) | ||
decoded_params = deserialize(encoded_bytes) | ||
|
||
model.apply(decoded_params, ...) | ||
``` | ||
|
||
* If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm | ||
|
||
```python | ||
from safejax import serialize, deserialize | ||
|
||
params, state = model.init(...) | ||
params_state = {"params": params, "state": state} | ||
|
||
encoded_bytes = serialize(params_state) | ||
decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state` | ||
|
||
model.apply(decoded_params_state["params"], decoded_params_state["state"], ...) | ||
``` | ||
|
||
* If it contains `params` and `state`, but we want to serialize those individually | ||
|
||
```python | ||
from safejax import serialize, deserialize | ||
|
||
params, state = model.init(...) | ||
|
||
encoded_bytes = serialize(params) | ||
decoded_params = deserialize(encoded_bytes) | ||
|
||
encoded_bytes = serialize(state) | ||
decoded_state = deserialize(encoded_bytes) | ||
|
||
model.apply(decoded_params, decoded_state, ...) | ||
``` | ||
|
||
--- | ||
|
||
## `objax` | ||
|
||
* Convert `params` to `bytes` in memory, and convert back to `VarCollection` | ||
|
||
```python | ||
from safejax import serialize, deserialize | ||
|
||
params = model.vars() | ||
|
||
encoded_bytes = serialize(params=params) | ||
decoded_params = deserialize( | ||
encoded_bytes, requires_unflattening=False, to_var_collection=True | ||
) | ||
|
||
for key, value in decoded_params.items(): | ||
if key in model.vars(): | ||
model.vars()[key].assign(value) | ||
|
||
model(...) | ||
``` | ||
|
||
* Convert `params` to `bytes` in `params.safetensors` file | ||
|
||
```python | ||
from safejax import serialize, deserialize | ||
|
||
params = model.vars() | ||
|
||
encoded_bytes = serialize(params=params, filename="./params.safetensors") | ||
decoded_params = deserialize("./params.safetensors", requires_unflattening=False) | ||
|
||
for key, value in decoded_params.items(): | ||
if key in model.vars(): | ||
model.vars()[key].assign(value) | ||
|
||
model(...) | ||
``` | ||
|
||
--- | ||
|
||
More in-detail examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`. |