🔀 Merge pull request #14 from alvarobartt/docs
📝 Add `mkdocs` and `mkdocs-material` documenation
alvarobartt authored Dec 26, 2022
2 parents 41c5966 + 740312a commit e49a07e
Showing 12 changed files with 433 additions and 1 deletion.
24 changes: 23 additions & 1 deletion .github/workflows/ci-cd.yaml
@@ -61,8 +61,30 @@ jobs:
      - name: run-tests
        run: pytest tests/ -s --durations 0

  publish-package:
  deploy-docs:
    needs: run-tests

    runs-on: ubuntu-latest

    steps:
      - name: checkout
        uses: actions/checkout@v3

      - name: setup-python
        uses: actions/setup-python@v4
        with:
          python-version: 3.8

      - name: install-dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[docs]"
      - name: deploy-to-gh-pages
        run: mkdocs gh-deploy --force

  publish-package:
    needs: deploy-docs
    if: github.event_name == 'release'

    runs-on: ubuntu-latest
2 changes: 2 additions & 0 deletions docs/api/load.md
@@ -0,0 +1,2 @@
::: safejax.load
    handler: python
2 changes: 2 additions & 0 deletions docs/api/save.md
@@ -0,0 +1,2 @@
::: safejax.save
    handler: python
206 changes: 206 additions & 0 deletions docs/examples.md
@@ -0,0 +1,206 @@
# 🤖 Examples

Below are detailed examples of how to use `safejax` to serialize model
parameters with `safetensors`, as opposed to the default approach, which
stores the tensors with `pickle`.

## **Flax**

Available at [`flax_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/flax_ft_safejax.py).

To run this Python script you don't need to install anything other than
`safejax`, as both `jax` and `flax` are installed as part of it.

In this case, a single-layer model is created. For now, `flax` doesn't
ship any pre-defined architectures such as ResNet, but you can use
[`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as
it defines some well-known architectures written in `flax`.

```python
import jax
from flax import linen as nn

class SingleLayerModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x
```

Once the network has been defined, we can instantiate and initialize it,
to retrieve the `params` out of the forward pass performed during
`.init`.

```python
import jax
from jax import numpy as jnp

network = SingleLayerModel(features=1)

rng_key = jax.random.PRNGKey(seed=0)
initial_params = network.init(rng_key, jnp.ones((1, 1)))
```

Right after getting the `params` from the output of `.init`, we can use
`safejax.serialize` to encode them with `safetensors`; they can later be
loaded back with `safejax.deserialize`.

```python
from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True)
```

As seen in the code above, we pass `freeze_dict=True` because its default
value is `False` and we want `safejax.deserialize` to freeze the `dict`
holding the params before returning it, which turns the `Dict` into a
`FrozenDict`.
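
To give an intuition of what "freezing" means, here is a minimal sketch that uses only the standard library. It is an analogy, not how `flax` implements `FrozenDict`: the `freeze` helper and the sample `params` are hypothetical, and `types.MappingProxyType` stands in for an immutable mapping.

```python
from types import MappingProxyType


def freeze(tree):
    """Recursively wrap nested dicts in read-only views (stdlib analogy only)."""
    if isinstance(tree, dict):
        return MappingProxyType({key: freeze(value) for key, value in tree.items()})
    return tree


# Hypothetical params, with plain lists standing in for tensors.
params = {"params": {"Dense_0": {"kernel": [[0.5]], "bias": [0.0]}}}
frozen = freeze(params)

try:
    frozen["params"]["Dense_0"]["bias"] = [1.0]
    mutated = True
except TypeError:
    # Read-only mappings reject item assignment.
    mutated = False
```

Reading from `frozen` works as with a regular `dict`, while writing to it raises a `TypeError`, which is the kind of guarantee a frozen params container is meant to give.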

Finally, we can use those `decoded_params` to run a forward pass
with the previously defined single-layer network.

```python
x = jnp.ones((1, 1))
y = network.apply(decoded_params, x)
```


## **Haiku**

Available at [`haiku_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/haiku_ft_safejax.py).

To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku)
installed.

A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet` and since
the purpose of the example is to show the integration of both `dm-haiku` and
`safejax`, we won't use pre-trained weights.

If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html).

First of all, let's create the network instance for the ResNet50 using `dm-haiku`
with the following code:

```python
import haiku as hk
from jax import numpy as jnp

def resnet_fn(x: jnp.DeviceArray, is_training: bool):
    resnet = hk.nets.ResNet50(num_classes=10)
    return resnet(x, is_training=is_training)

network = hk.without_apply_rng(hk.transform_with_state(resnet_fn))
```

Some notes on the code above:
* `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter
* `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter
* The function needs to be wrapped with `hk.transform_with_state`, as we want to preserve
  the state, e.g. the ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state.
* Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng.

Then we just initialize the network to retrieve both the `params` and the `state`,
which again, are random.

```python
import jax

rng_key = jax.random.PRNGKey(seed=0)
initial_params, initial_state = network.init(
    rng_key, jnp.ones([1, 224, 224, 3]), is_training=True
)
```

Once we have the `params`, we can use `safejax.serialize` to serialize them
using `safetensors` as the tensor storage format; they can later be loaded
back with `safejax.deserialize` and used for the network's inference.

```python
from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes)
```
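
Since `safetensors` stores a flat mapping of names to tensors, nested params have to be flattened before being written and unflattened after being read. The following is a minimal sketch of that idea in plain Python; the `flatten` and `unflatten` helpers, the `.` separator, and the sample tree are assumptions made for illustration, not necessarily what `safejax` does internally.

```python
def flatten(tree, separator=".", prefix=""):
    """Turn a nested dict into a flat dict whose keys are joined paths."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}{separator}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, separator, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat, separator="."):
    """Rebuild the nested dict from the joined paths."""
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split(separator)
        node = tree
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return tree


# Hypothetical params, with plain lists standing in for tensors.
nested = {"linear": {"w": [[1.0, 2.0]], "b": [0.0]}, "scale": [3.0]}
flat = flatten(nested)
restored = unflatten(flat)
```

Note that this sketch only round-trips correctly when no key already contains the separator, so a real implementation has to pick the separator with the naming scheme of the framework in mind.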

Finally, let's use those `decoded_params` to run inference with the network.

```python
x = jnp.ones([1, 224, 224, 3])
y, _ = network.apply(decoded_params, initial_state, x, is_training=False)
```

## **Objax**

Available at [`objax_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/objax_ft_safejax.py).

To run this Python script you don't need to install anything other than
`safejax`, as both `jax` and `objax` are installed as part of it.

In this case, we'll be using one of the architectures defined in the model zoo
of `objax` at [`objax/zoo`](https://github.com/google/objax/tree/master/objax/zoo),
which is ResNet50. So first of all, let's initialize it:

```python
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)
```

Once initialized, we can access the model params, which in `objax` are stored
in `model.vars()` and are of type `VarCollection`, a dictionary-like class. We
can therefore serialize them with `safejax.serialize` using the `safetensors` format
instead of `pickle`, which is the currently recommended way; see https://objax.readthedocs.io/en/latest/advanced/io.html.

```python
from safejax import serialize

encoded_bytes = serialize(params=model.vars())
```

Then we can just deserialize those params back using `safejax.deserialize`, and
we'll end up getting the same `VarCollection` dictionary back. Note that we need
to disable the unflattening with `requires_unflattening=False` as it's not required
due to the way it's stored, and set `to_var_collection=True` to get a `VarCollection`
instead of a `Dict[str, jnp.DeviceArray]`, even though it will work with a standard
dict too.

```python
from safejax import deserialize

decoded_params = deserialize(
    encoded_bytes, requires_unflattening=False, to_var_collection=True
)
```

Once decoded with `safejax.deserialize`, we need to assign those key-value
pairs back to the `VarCollection` of the ResNet50 one by one, as `.update` in
`objax` has been redefined (see https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311)
and is not consistent with the standard `dict.update` (already reported at
https://github.com/google/objax/issues/254). So, instead, we loop over
all the key-value pairs in the decoded params and assign each one to the
`VarCollection` in `model.vars()`.

```python
for key, value in decoded_params.items():
    if key not in model.vars():
        print(f"Key {key} not in model.vars()! Skipping.")
        continue
    model.vars()[key].assign(value)
```
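
The loop above skips keys that the model doesn't know about, but it doesn't tell us whether some model variables were never assigned. A small check like the following can report both cases before assigning anything. It is a sketch on plain dicts: `compare_keys` is a hypothetical helper (not part of `safejax` or `objax`) and the key names are made up.

```python
def compare_keys(decoded, target):
    """Report which keys match, which are unexpected, and which are missing."""
    decoded_keys, target_keys = set(decoded), set(target)
    return {
        "matched": sorted(decoded_keys & target_keys),
        "unexpected": sorted(decoded_keys - target_keys),  # in the file only
        "missing": sorted(target_keys - decoded_keys),  # in the model only
    }


# Hypothetical key sets; any dictionary-like objects work here.
decoded = {"conv.w": 1, "conv.b": 2, "extra.w": 3}
target = {"conv.w": 0, "conv.b": 0, "head.w": 0}
report = compare_keys(decoded, target)
```

Assuming `VarCollection` can be iterated like a `dict`, the same call could be made as `compare_keys(decoded_params, model.vars())`.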

And, finally, we can run the inference over the model via the `__call__` method,
as the `.vars()` now hold the params returned by `safejax.deserialize`.

```python
from jax import numpy as jnp

x = jnp.ones((1, 3, 224, 224))
y = model(x, training=False)
```

Note that we're setting the `training` flag to `False`, which is the standard way
of running the inference over a pre-trained model in `objax`.
10 changes: 10 additions & 0 deletions docs/index.md
@@ -0,0 +1,10 @@
# 🔐 Serialize JAX, Flax, Haiku, or Objax model params with `safetensors`

`safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors`
as the tensor storage format, instead of relying on `pickle`. For more details on why
`safetensors` is safer than `pickle` please check https://github.com/huggingface/safetensors.

Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model
parameters and has been tested with all of those frameworks. It is still in an early development
phase, though, so there may be cases where it does not work as expected; if you have any feedback
or bug reports, please open an issue at https://github.com/alvarobartt/safejax/issues.
5 changes: 5 additions & 0 deletions docs/installation.md
@@ -0,0 +1,5 @@
# ⬇️ Installation

```bash
pip install safejax --upgrade
```
23 changes: 23 additions & 0 deletions docs/license.md
@@ -0,0 +1,23 @@
# 📝 License

MIT License

Copyright (c) 2022-present Alvaro Bartolome <alvarobartt@yahoo.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 changes: 3 additions & 0 deletions docs/requirements.md
@@ -0,0 +1,3 @@
# 🛠️ Requirements

`safejax` requires Python 3.7 or above
49 changes: 49 additions & 0 deletions docs/usage.md
@@ -0,0 +1,49 @@
# 💻 Usage

Let's create a `flax` model using the Linen API and initialize it.

```python
import jax
from flax import linen as nn
from jax import numpy as jnp

class SingleLayerModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x

model = SingleLayerModel(features=1)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))
```

Once done, we can save the model params with `safejax.serialize`, which uses
the `safetensors` storage format.

```python
from safejax import serialize

serialized_params = serialize(params=params)
```
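
To get a feel for what the serialized bytes contain, here is a simplified sketch of the `safetensors` layout: an 8-byte little-endian header size, a JSON header describing each tensor, and the raw tensor bytes. It is not the library's implementation: the `encode` and `decode` helpers are hypothetical, only `float32` values are handled, and tensors are represented as `(shape, values)` pairs of plain Python objects.

```python
import json
import struct


def encode(tensors):
    """tensors: dict of name -> (shape, float32 values in row-major order)."""
    header, buffer = {}, b""
    for name, (shape, values) in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)
        header[name] = {
            "dtype": "F32",
            "shape": list(shape),
            # Offsets are relative to the start of the byte buffer.
            "data_offsets": [len(buffer), len(buffer) + len(raw)],
        }
        buffer += raw
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + buffer


def decode(blob):
    """Read the header first, then slice each tensor out of the byte buffer."""
    (header_size,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_size])
    buffer = blob[8 + header_size :]
    tensors = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]
        values = struct.unpack(f"<{(end - begin) // 4}f", buffer[begin:end])
        tensors[name] = (tuple(info["shape"]), list(values))
    return tensors


blob = encode({"Dense_0.kernel": ((1, 1), [0.5]), "Dense_0.bias": ((1,), [0.0])})
decoded = decode(blob)
```

Because the header only describes names, dtypes, shapes, and offsets, decoding never has to execute anything stored in the file.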

Those params can later be loaded with `safejax.deserialize` and used
to run inference with the model.

```python
from safejax import deserialize

params = deserialize(path_or_buf=serialized_params, freeze_dict=True)
```

Finally, run the inference:

```python
x = jnp.ones((1, 28, 28, 1))
y = model.apply(params, x)
```

More detailed examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`.
32 changes: 32 additions & 0 deletions docs/why_safejax.md
@@ -0,0 +1,32 @@
# 🤔 Why `safejax`?

`safetensors` defines an easy and fast (zero-copy) format to store tensors,
while `pickle` has some known weaknesses and security issues. `safetensors`
is also intended to be independent of the framework used to load the
tensors. More in-depth information can be found at
https://github.com/huggingface/safetensors.
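
As a concrete illustration of the `pickle` issue, the harmless sketch below shows that loading a pickle can call an arbitrary function chosen by whoever produced the bytes. The `Payload` class is hypothetical and only evaluates an arithmetic expression, but the same mechanism can run any callable.

```python
import pickle


class Payload:
    def __reduce__(self):
        # Tells pickle: "to rebuild this object, call eval('6 * 7')".
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())

# Loading the bytes calls eval; no Payload instance is returned.
result = pickle.loads(blob)
```

This is why untrusted pickle files should never be loaded, and why a format that only stores tensor data avoids this class of problem.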

Both `jax` and `haiku` use `pytrees` to store the model parameters in memory, so
the params are held in a dictionary-like structure containing nested `jnp.DeviceArray` tensors.

Then `objax` defines a custom dictionary-like class named `VarCollection` that contains
some variables inheriting from `BaseVar` which is another custom `objax` type.

`flax` defines a dictionary-like class named `FrozenDict` that is used to
store the tensors in memory; it can be dumped either into `bytes` in `MessagePack`
format or as a `state_dict`.

Anyway, `flax` still uses `pickle` as the format for storing the tensors, and
there are no plans from HuggingFace to extend `safetensors` to support anything
more than tensors, e.g. `FrozenDict`s; see their response at
https://github.com/huggingface/safetensors/discussions/138.

So `safejax` was created to easily provide a way to serialize `FrozenDict`s
using `safetensors` as the tensor storage format instead of `pickle`.

## 📄 Main differences with `flax.serialization`

* `flax.serialization.to_bytes` uses `pickle` as the tensor storage format, while
`safejax.serialize` uses `safetensors`
* `flax.serialization.from_bytes` requires the `target` to be instantiated, while
`safejax.deserialize` just needs the encoded bytes