generated from alvarobartt/python-package-template
-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🔀 Merge pull request #14 from alvarobartt/docs
📝 Add `mkdocs` and `mkdocs-material` documenation
- Loading branch information
Showing
12 changed files
with
433 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
::: safejax.load | ||
handler: python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
::: safejax.save | ||
handler: python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
# 🤖 Examples | ||
|
||
Here you will find some detailed examples of how to use `safejax` to serialize | ||
model parameters, in opposition to the default way to store those, which uses | ||
`pickle` as the format to store the tensors instead of `safetensors`. | ||
|
||
## **Flax** | ||
|
||
Available at [`flax_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/flax_ft_safejax.py). | ||
|
||
To run this Python script you won't need to install anything else than | ||
`safejax`, as both `jax` and `flax` are installed as part of it. | ||
|
||
In this case, a single-layer model will be created, for now, `flax` | ||
doesn't have any pre-defined architecture such as ResNet, but you can use | ||
[`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as | ||
it defines some well-known architectures written in `flax`. | ||
|
||
```python | ||
import jax | ||
from flax import linen as nn | ||
|
||
class SingleLayerModel(nn.Module): | ||
features: int | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
x = nn.Dense(features=self.features)(x) | ||
return x | ||
``` | ||
|
||
Once the network has been defined, we can instantiate and initialize it, | ||
to retrieve the `params` out of the forward pass performed during | ||
`.init`. | ||
|
||
```python | ||
import jax | ||
from jax import numpy as jnp | ||
|
||
network = SingleLayerModel(features=1) | ||
|
||
rng_key = jax.random.PRNGKey(seed=0) | ||
initial_params = network.init(rng_key, jnp.ones((1, 1))) | ||
``` | ||
|
||
Right after getting the `params` from the `.init` method's output, we can | ||
use `safejax.serialize` to encode those using `safetensors`, that later on | ||
can be loaded back using `safejax.deserialize`. | ||
|
||
```python | ||
from safejax import deserialize, serialize | ||
|
||
encoded_bytes = serialize(params=initial_params) | ||
decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) | ||
``` | ||
|
||
As seen in the code above, we're using `freeze_dict=True` since its default | ||
value is False, as we want to freeze the `dict` with the params before actually | ||
returning it during `safejax.deserialize`, this transforms the `Dict` | ||
into a `FrozenDict`. | ||
|
||
Finally, we can use those `decoded_params` to run a forward pass | ||
with the previously defined single-layer network. | ||
|
||
```python | ||
x = jnp.ones((1, 1)) | ||
y = network.apply(decoded_params, x) | ||
``` | ||
|
||
|
||
## **Haiku** | ||
|
||
Available at [`haiku_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/haiku_ft_safejax.py). | ||
|
||
To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku) | ||
installed. | ||
|
||
A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet` and since | ||
the purpose of the example is to show the integration of both `dm-haiku` and | ||
`safejax`, we won't use pre-trained weights. | ||
|
||
If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html). | ||
|
||
First of all, let's create the network instance for the ResNet50 using `dm-haiku` | ||
with the following code: | ||
|
||
```python | ||
import haiku as hk | ||
from jax import numpy as jnp | ||
|
||
def resnet_fn(x: jnp.DeviceArray, is_training: bool): | ||
resnet = hk.nets.ResNet50(num_classes=10) | ||
return resnet(x, is_training=is_training) | ||
|
||
network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) | ||
``` | ||
|
||
Some notes on the code above: | ||
* `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter | ||
* `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter | ||
* It needs to be initialized with `hk.transform_with_state` as we want to preserve | ||
the state e.g. ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state. | ||
* Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng. | ||
|
||
Then we just initialize the network to retrieve both the `params` and the `state`, | ||
which again, are random. | ||
|
||
```python | ||
import jax | ||
|
||
rng_key = jax.random.PRNGKey(seed=0) | ||
initial_params, initial_state = network.init( | ||
rng_key, jnp.ones([1, 224, 224, 3]), is_training=True | ||
) | ||
``` | ||
|
||
Now once we have the `params`, we can import `safejax.serialize` to serialize the | ||
params using `safetensors` as the tensor storage format, that later on can be loaded | ||
back using `safejax.deserialize` and used for the network's inference. | ||
|
||
```python | ||
from safejax import deserialize, serialize | ||
|
||
encoded_bytes = serialize(params=initial_params) | ||
decoded_params = deserialize(path_or_buf=encoded_bytes) | ||
``` | ||
|
||
Finally, let's just use those `decoded_params` to run the inference over the network | ||
using those weights. | ||
|
||
```python | ||
x = jnp.ones([1, 224, 224, 3]) | ||
y, _ = network.apply(decoded_params, initial_state, x, is_training=False) | ||
``` | ||
|
||
## **Objax** | ||
|
||
Available at [`objax_ft_safejax.py`](https://github.com/alvarobartt/safejax/examples/objax_ft_safejax.py). | ||
|
||
To run this Python script you won't need to install anything else than | ||
`safejax`, as both `jax` and `objax` are installed as part of it. | ||
|
||
In this case, we'll be using one of the architectures defined in the model zoo | ||
of `objax` at [`objax/zoo`](https://github.com/google/objax/tree/master/objax/zoo), | ||
which is ResNet50. So first of all, let's initialize it: | ||
|
||
```python | ||
from objax.zoo.resnet_v2 import ResNet50 | ||
|
||
model = ResNet50(in_channels=3, num_classes=1000) | ||
``` | ||
|
||
Once initialized, we can already access the model params which in `objax` are stored | ||
in `model.vars()` and are of type `VarCollection` which is a dictionary-like class. So | ||
on, we can already serialize those using `safejax.serialize` and `safetensors` format | ||
instead of `pickle` which is the current recommended way, see https://objax.readthedocs.io/en/latest/advanced/io.html. | ||
|
||
```python | ||
from safejax import serialize | ||
|
||
encoded_bytes = serialize(params=model.vars()) | ||
``` | ||
|
||
Then we can just deserialize those params back using `safejax.deserialize`, and | ||
we'll end up getting the same `VarCollection` dictionary back. Note that we need | ||
to disable the unflattening with `requires_unflattening=False` as it's not required | ||
due to the way it's stored, and set `to_var_collection=True` to get a `VarCollection` | ||
instead of a `Dict[str, jnp.DeviceArray]`, even though it will work with a standard | ||
dict too. | ||
|
||
```python | ||
from safejax import deserialize | ||
|
||
decoded_params = deserialize( | ||
encoded_bytes, requires_unflattening=False, to_var_collection=True | ||
) | ||
``` | ||
|
||
Now, once decoded with `safejax.deserialize` we need to assign those key-value | ||
pais back to the `VarCollection` of the ResNet50 via assignment, as `.update` in | ||
`objax` has been redefined, see https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311, | ||
and it's not consistent with the standard `dict.update` (already reported at | ||
https://github.com/google/objax/issues/254). So, instead, we need to loop over | ||
all the key-value pairs in the decoded params and assign those one by one to the | ||
`VarCollection` in `model.vars()`. | ||
|
||
```python | ||
for key, value in decoded_params.items(): | ||
if key not in model.vars(): | ||
print(f"Key {key} not in model.vars()! Skipping.") | ||
continue | ||
model.vars()[key].assign(value) | ||
``` | ||
|
||
And, finally, we can run the inference over the model via the `__call__` method | ||
as the `.vars()` are already copied from the params resulting of `safejax.deserialize`. | ||
|
||
```python | ||
from jax import numpy as jnp | ||
|
||
x = jnp.ones((1, 3, 224, 224)) | ||
y = model(x, training=False) | ||
``` | ||
|
||
Note that we're setting the `training` flag to `False`, which is the standard way | ||
of running the inference over a pre-trained model in `objax`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# 🔐 Serialize JAX, Flax, Haiku, or Objax model params with `safetensors` | ||
|
||
`safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors` | ||
as the tensor storage format, instead of relying on `pickle`. For more details on why | ||
`safetensors` is safer than `pickle` please check https://github.com/huggingface/safetensors. | ||
|
||
Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model | ||
parameters and has been tested with all those frameworks, but there may be some cases where it | ||
does not work as expected, as this is still in an early development phase, so please if you have | ||
any feedback or bug reports, open an issue at https://github.com/alvarobartt/safejax/issues |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# ⬇️ Installation | ||
|
||
```bash | ||
pip install safejax --upgrade | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# 📝 License | ||
|
||
MIT License | ||
|
||
Copyright (c) 2022-present Alvaro Bartolome <alvarobartt@yahoo.com> | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# 🛠️ Requirements | ||
|
||
`safejax` requires Python 3.7 or above |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# 💻 Usage | ||
|
||
Let's create a `flax` model using the Linen API and initialize it. | ||
|
||
```python | ||
import jax | ||
from flax import linen as nn | ||
from jax import numpy as jnp | ||
|
||
class SingleLayerModel(nn.Module): | ||
features: int | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
x = nn.Dense(features=self.features)(x) | ||
return x | ||
|
||
model = SingleLayerModel(features=1) | ||
|
||
rng = jax.random.PRNGKey(0) | ||
params = model.init(rng, jnp.ones((1, 1))) | ||
``` | ||
|
||
Once done, we can already save the model params with `safejax` (using `safetensors` | ||
storage format) using `safejax.serialize`. | ||
|
||
```python | ||
from safejax import serialize | ||
|
||
serialized_params = serialize(params=params) | ||
``` | ||
|
||
Those params can be later loaded using `safejax.deserialize` and used | ||
to run the inference over the model using those weights. | ||
|
||
```python | ||
from safejax import deserialize | ||
|
||
params = deserialize(path_or_buf=serialized_params, freeze_dict=True) | ||
``` | ||
|
||
And, finally, running the inference as: | ||
|
||
```python | ||
x = jnp.ones((1, 28, 28, 1)) | ||
y = model.apply(params, x) | ||
``` | ||
|
||
More in-detail examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# 🤔 Why `safejax`? | ||
|
||
`safetensors` defines an easy and fast (zero-copy) format to store tensors, | ||
while `pickle` has some known weaknesses and security issues. `safetensors` | ||
is also a storage format that is intended to be trivial to the framework | ||
used to load the tensors. More in-depth information can be found at | ||
https://github.com/huggingface/safetensors. | ||
|
||
Both `jax` and `haiku` use `pytrees` to store the model parameters in memory, so | ||
it's a dictionary-like class containing nested `jnp.DeviceArray` tensors. | ||
|
||
Then `objax` defines a custom dictionary-like class named `VarCollection` that contains | ||
some variables inheriting from `BaseVar` which is another custom `objax` type. | ||
|
||
`flax` defines a dictionary-like class named `FrozenDict` that is used to | ||
store the tensors in memory, it can be dumped either into `bytes` in `MessagePack` | ||
format or as a `state_dict`. | ||
|
||
Anyway, `flax` still uses `pickle` as the format for storing the tensors, so | ||
there are no plans from HuggingFace to extend `safetensors` to support anything | ||
more than tensors e.g. `FrozenDict`s, see their response at | ||
https://github.com/huggingface/safetensors/discussions/138. | ||
|
||
So `safejax` was created to easily provide a way to serialize `FrozenDict`s | ||
using `safetensors` as the tensor storage format instead of `pickle`. | ||
|
||
## 📄 Main differences with `flax.serialization` | ||
|
||
* `flax.serialization.to_bytes` uses `pickle` as the tensor storage format, while | ||
`safejax.serialize` uses `safetensors` | ||
* `flax.serialization.from_bytes` requires the `target` to be instantiated, while | ||
`safejax.deserialize` just needs the encoded bytes |
Oops, something went wrong.