🔀 Merge pull request #16 from alvarobartt/readme-update
📝 Update `README.md` and `docs/usage.md`
alvarobartt authored Dec 27, 2022
2 parents 224f8dd + 200c4a9 commit 3ac50f6
Showing 2 changed files with 208 additions and 66 deletions.
139 changes: 105 additions & 34 deletions README.md
@@ -2,12 +2,12 @@

`safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors`
as the tensor storage format, instead of relying on `pickle`. For more details on why
`safetensors` is safer than `pickle` please check https://github.com/huggingface/safetensors.
`safetensors` is safer than `pickle`, please check [huggingface/safetensors](https://github.com/huggingface/safetensors).
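
To make the `pickle` concern concrete, here is a minimal standard-library sketch. The `Payload` class is hypothetical and exists only for illustration. It shows that unpickling calls whatever callable the byte stream names, so loading untrusted bytes can run arbitrary code; the callable here is a harmless `eval` of an arithmetic expression.

```python
import pickle


class Payload:
    """Hypothetical object whose pickle stream names a callable to run on load."""

    def __reduce__(self):
        # pickle records "call eval('6 * 7')" instead of the object's state.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())

# Loading does not rebuild a Payload; it calls eval with the stored argument.
loaded = pickle.loads(blob)
print(type(loaded).__name__, loaded)
```

A format that stores only names, shapes, dtypes, and raw bytes has no equivalent hook, which is the property `safejax` relies on.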

Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model
parameters and has been tested with all those frameworks, but there may be some cases where it
does not work as expected, as this is still in an early development phase. If you have
any feedback or bug reports, please open an issue at https://github.com/alvarobartt/safejax/issues
any feedback or bug reports, please open an issue at [safejax/issues](https://github.com/alvarobartt/safejax/issues).

## 🛠️ Requirements & Installation

@@ -19,51 +19,122 @@ pip install safejax --upgrade

## 💻 Usage

Let's create a `flax` model using the Linen API and initialize it.
### `flax`

```python
import jax
from flax import linen as nn
from jax import numpy as jnp
* Convert `params` to `bytes` in memory

class SingleLayerModel(nn.Module):
    features: int
```python
from safejax import serialize, deserialize

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x
params = model.init(...)

model = SingleLayerModel(features=1)
encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes, freeze_dict=True)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))
```
model.apply(decoded_params, ...)
```

Once done, we can already save the model params with `safejax` (using `safetensors`
storage format) using `safejax.serialize`.
* Serialize `params` to a `params.safetensors` file

```python
from safejax import serialize
```python
from safejax import serialize, deserialize

serialized_params = serialize(params=params)
```
params = model.init(...)

Those params can be later loaded using `safejax.deserialize` and used
to run the inference over the model using those weights.
encoded_bytes = serialize(params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", freeze_dict=True)

```python
from safejax import deserialize
model.apply(decoded_params, ...)
```

params = deserialize(path_or_buf=serialized_params, freeze_dict=True)
```
---

And, finally, running the inference as:
### `dm-haiku`

```python
x = jnp.ones((1, 28, 28, 1))
y = model.apply(params, x)
```
* If it only contains `params`

```python
from safejax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```

* If it contains `params` and `state`, e.g. `ExponentialMovingAverage` in `BatchNorm`

```python
from safejax import serialize, deserialize

params, state = model.init(...)
params_state = {"params": params, "state": state}

encoded_bytes = serialize(params_state)
decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state`

model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)
```

* If it contains `params` and `state`, but we want to serialize those individually

```python
from safejax import serialize, deserialize

params, state = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

encoded_bytes = serialize(state)
decoded_state = deserialize(encoded_bytes)

model.apply(decoded_params, decoded_state, ...)
```
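
`safetensors` stores a flat mapping from string names to tensors, so a nested tree such as `{"params": ..., "state": ...}` has to be flattened before saving and rebuilt after loading. The sketch below illustrates that round trip with plain Python dicts, using lists as stand-ins for arrays. The helper names `flatten` and `unflatten` and the `.` separator are assumptions made for illustration and may differ from what `safejax` does internally.

```python
from typing import Any, Dict


def flatten(tree: Dict[str, Any], prefix: str = "", sep: str = ".") -> Dict[str, Any]:
    """Collapse a nested dict into one level, joining keys with `sep`."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name, sep))
        else:
            flat[name] = value
    return flat


def unflatten(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict by splitting each flat key on `sep`."""
    tree: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(sep)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Made-up haiku-style names; lists stand in for arrays.
params_state = {
    "params": {"linear": {"w": [[0.5]], "b": [0.0]}},
    "state": {"batch_norm/~/mean_ema": {"average": [0.1]}},
}

flat = flatten(params_state)
restored = unflatten(flat)
print(sorted(flat))
```

Because both `params` and `state` survive as top-level prefixes, the restored tree exposes the same two keys. Note that this simple scheme breaks if a key itself contains the separator.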

---

### `objax`

* Convert `params` to `bytes` in memory, and convert back to `VarCollection`

```python
from safejax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params)
decoded_params = deserialize(
    encoded_bytes, requires_unflattening=False, to_var_collection=True
)

for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value)

model(...)
```

* Serialize `params` to a `params.safetensors` file

```python
from safejax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", requires_unflattening=False)

for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value)

model(...)
```
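
The assignment loop above only copies values whose keys already exist in the model, so a file that does not match the model is skipped silently. Below is a minimal sketch of the same pattern with plain dicts that also reports what was skipped. `restore` is a hypothetical helper and the key names are made up; with a real `objax` model each value would be passed to `.assign(value)` instead of being written into a dict.

```python
from typing import Any, Dict, List, Tuple


def restore(
    model_vars: Dict[str, Any], loaded: Dict[str, Any]
) -> Tuple[List[str], List[str]]:
    """Copy matching entries into model_vars; return (missing, unexpected) keys."""
    for key, value in loaded.items():
        if key in model_vars:
            model_vars[key] = value
    # Keys the model expects but the file did not provide.
    missing = sorted(set(model_vars) - set(loaded))
    # Keys the file provided but the model does not know.
    unexpected = sorted(set(loaded) - set(model_vars))
    return missing, unexpected


model_vars = {"(Model).linear.w": [0.0], "(Model).linear.b": [0.0]}
loaded = {"(Model).linear.w": [1.5], "(Model).extra.w": [9.0]}

missing, unexpected = restore(model_vars, loaded)
print("missing:", missing)
print("unexpected:", unexpected)
```

Checking that both lists are empty before running the model is a cheap way to catch a checkpoint saved from a different architecture.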

---

More detailed examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`.

135 changes: 103 additions & 32 deletions docs/usage.md
@@ -1,49 +1,120 @@
# 💻 Usage

Let's create a `flax` model using the Linen API and initialize it.
## `flax`

```python
import jax
from flax import linen as nn
from jax import numpy as jnp
* Convert `params` to `bytes` in memory

class SingleLayerModel(nn.Module):
    features: int
```python
from safejax import serialize, deserialize

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x
params = model.init(...)

model = SingleLayerModel(features=1)
encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes, freeze_dict=True)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))
```
model.apply(decoded_params, ...)
```

Once done, we can already save the model params with `safejax` (using `safetensors`
storage format) using `safejax.serialize`.
* Serialize `params` to a `params.safetensors` file

```python
from safejax import serialize
```python
from safejax import serialize, deserialize

serialized_params = serialize(params=params)
```
params = model.init(...)

Those params can be later loaded using `safejax.deserialize` and used
to run the inference over the model using those weights.
encoded_bytes = serialize(params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", freeze_dict=True)

```python
from safejax import deserialize
model.apply(decoded_params, ...)
```

params = deserialize(path_or_buf=serialized_params, freeze_dict=True)
```
---

And, finally, running the inference as:
## `dm-haiku`

```python
x = jnp.ones((1, 28, 28, 1))
y = model.apply(params, x)
```
* If it only contains `params`

```python
from safejax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```

* If it contains `params` and `state`, e.g. `ExponentialMovingAverage` in `BatchNorm`

```python
from safejax import serialize, deserialize

params, state = model.init(...)
params_state = {"params": params, "state": state}

encoded_bytes = serialize(params_state)
decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state`

model.apply(decoded_params_state["params"], decoded_params_state["state"], ...)
```

* If it contains `params` and `state`, but we want to serialize those individually

```python
from safejax import serialize, deserialize

params, state = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes)

encoded_bytes = serialize(state)
decoded_state = deserialize(encoded_bytes)

model.apply(decoded_params, decoded_state, ...)
```

---

## `objax`

* Convert `params` to `bytes` in memory, and convert back to `VarCollection`

```python
from safejax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params)
decoded_params = deserialize(
    encoded_bytes, requires_unflattening=False, to_var_collection=True
)

for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value)

model(...)
```

* Serialize `params` to a `params.safetensors` file

```python
from safejax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", requires_unflattening=False)

for key, value in decoded_params.items():
    if key in model.vars():
        model.vars()[key].assign(value)

model(...)
```
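
For context on what a `.safetensors` file holds, the format is an 8-byte little-endian header length, a JSON header describing each tensor (`dtype`, `shape`, `data_offsets`), and then the raw tensor bytes. The sketch below imitates that layout with the standard library only. `pack` and `unpack` are hypothetical helpers restricted to 1-D `float32` lists; this is a simplified illustration, not the `safetensors` library or the `safejax` implementation, and it skips the validation a real reader performs.

```python
import json
import struct
from typing import Dict, List


def pack(tensors: Dict[str, List[float]]) -> bytes:
    """Pack 1-D float32 lists into a safetensors-style buffer (simplified)."""
    header, chunks, offset = {}, [], 0
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)
        header[name] = {
            "dtype": "F32",
            "shape": [len(values)],
            # Offsets are relative to the start of the data section.
            "data_offsets": [offset, offset + len(raw)],
        }
        chunks.append(raw)
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


def unpack(buffer: bytes) -> Dict[str, List[float]]:
    """Read the buffer back; only parsing happens, no code is executed."""
    (header_len,) = struct.unpack("<Q", buffer[:8])
    header = json.loads(buffer[8 : 8 + header_len])
    data = buffer[8 + header_len :]
    out = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]
        count = (end - begin) // 4  # 4 bytes per float32
        out[name] = list(struct.unpack(f"<{count}f", data[begin:end]))
    return out


blob = pack({"linear.w": [0.5, -1.0], "linear.b": [0.25]})
decoded = unpack(blob)
print(decoded)
```

Reading the file back is pure parsing of JSON and fixed-width numbers, which is why loading it does not carry the code-execution risk of `pickle`.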

---

More detailed examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`.
