💻 Examples

Here you will find some detailed examples of how to use safejax to serialize model parameters, as opposed to the default approach, which stores the tensors using pickle instead of safetensors.
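For context, a safetensors buffer has three parts:

- an 8-byte little-endian header length;
- a JSON header that maps each tensor name to its dtype, shape and byte offsets;
- the raw tensor bytes.

Unlike pickle, no executable code is involved. The following is a simplified, stdlib-only sketch of that layout. It doesn't write the optional `__metadata__` entry or the whitespace padding that real writers may add to the header. `encode_safetensors` and `decode_safetensors` are hypothetical helpers for illustration, not part of safejax.

```python
import json
import struct


def encode_safetensors(tensors: dict) -> bytes:
    """Encode {name: (dtype, shape, raw_bytes)} following the safetensors layout."""
    header = {}
    offset = 0
    for name, (dtype, shape, data) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            # Offsets are relative to the start of the data buffer, not the file.
            "data_offsets": [offset, offset + len(data)],
        }
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    # 8-byte little-endian unsigned header length, then the JSON header, then the data.
    data_bytes = b"".join(data for _, _, data in tensors.values())
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data_bytes


def decode_safetensors(buf: bytes) -> dict:
    """Decode a buffer produced by encode_safetensors back into {name: (dtype, shape, raw_bytes)}."""
    (header_len,) = struct.unpack("<Q", buf[:8])
    header = json.loads(buf[8 : 8 + header_len].decode("utf-8"))
    data = buf[8 + header_len :]
    out = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form metadata, not a tensor
            continue
        start, end = info["data_offsets"]
        out[name] = (info["dtype"], tuple(info["shape"]), data[start:end])
    return out


# Two float32 tensors shaped like a Dense(features=1) layer's params (made-up values).
kernel = struct.pack("<f", 0.5)
bias = struct.pack("<f", 0.0)
buf = encode_safetensors({"kernel": ("F32", (1, 1), kernel), "bias": ("F32", (1,), bias)})
decoded = decode_safetensors(buf)
```

With safejax you don't build these bytes by hand: safejax.serialize produces them from the params.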

To run the flax example script, you won't need to install anything other than safejax, as both jax and flax are installed as part of it.

In this case, we'll create a single-layer model, since flax doesn't (at the time of writing) ship any pre-defined architectures such as ResNet. If you need those, you can use flaxmodels, which implements some well-known architectures in flax.

import jax
from flax import linen as nn

class SingleLayerModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x

Once the network has been defined, we can instantiate and initialize it. The .init method runs a forward pass over a dummy input and returns the initialized params.

import jax
from jax import numpy as jnp

network = SingleLayerModel(features=1)

rng_key = jax.random.PRNGKey(seed=0)
initial_params = network.init(rng_key, jnp.ones((1, 1)))

Right after getting the params from the output of .init, we can use safejax.serialize to encode them using safetensors. They can later be loaded back using safejax.deserialize.

from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True)

Note that we set freeze_dict=True, since its default value is False. We want safejax.deserialize to freeze the params before returning them, which transforms the Dict into a FrozenDict.
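Freezing here means wrapping the params in an immutable mapping. To illustrate the idea without flax, this sketch uses the standard library's types.MappingProxyType as a stand-in. This is an assumption for illustration only; it is not how flax implements FrozenDict. The `freeze` helper below is hypothetical.

```python
from types import MappingProxyType


def freeze(tree):
    """Recursively wrap nested dicts in read-only mapping proxies (stand-in for FrozenDict)."""
    if isinstance(tree, dict):
        return MappingProxyType({k: freeze(v) for k, v in tree.items()})
    return tree  # leaves (arrays, lists, numbers) are returned unchanged


params = {"params": {"Dense_0": {"kernel": [[0.5]], "bias": [0.0]}}}
frozen = freeze(params)

try:
    # Item assignment on a mappingproxy raises TypeError, just like mutating a FrozenDict fails.
    frozen["params"]["Dense_0"]["bias"] = [1.0]
except TypeError:
    print("frozen params can't be modified in place")
```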

Finally, we can use those decoded_params to run a forward pass with the previously defined single-layer network.

x = jnp.ones((1, 1))
y = network.apply(decoded_params, x)
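Under the hood, nn.Dense computes x @ kernel + bias. Here, kernel has shape (in_features, features) and bias has shape (features,). The following dependency-free sketch uses nested lists instead of jax arrays. The params layout {"params": {"Dense_0": {...}}} mirrors what flax creates for SingleLayerModel, but the values are made up, and `dense_apply` is a hypothetical helper.

```python
def dense_apply(variables: dict, x: list) -> list:
    """Compute x @ kernel + bias for a batch of row vectors given as nested lists."""
    layer = variables["params"]["Dense_0"]
    kernel, bias = layer["kernel"], layer["bias"]  # shapes (in, out) and (out,)
    return [
        [sum(row[i] * kernel[i][j] for i in range(len(row))) + bias[j] for j in range(len(bias))]
        for row in x
    ]


variables = {"params": {"Dense_0": {"kernel": [[0.5]], "bias": [0.25]}}}
print(dense_apply(variables, [[1.0]]))  # [[0.75]]
```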

To run the dm-haiku example script, you'll need to have both safejax and dm-haiku installed.

We'll use the ResNet50 architecture from haiku.nets.imagenet.resnet. Since the purpose of this example is to show the integration of dm-haiku and safejax, we won't use pre-trained weights.

If you're not familiar with dm-haiku, please visit Haiku Basics.

First of all, let's create the network instance for the ResNet50 using dm-haiku with the following code:

import haiku as hk
import jax
from jax import numpy as jnp

# jax.Array replaces jnp.DeviceArray, which was deprecated and later removed from JAX.
def resnet_fn(x: jax.Array, is_training: bool):
    resnet = hk.nets.ResNet50(num_classes=10)
    return resnet(x, is_training=is_training)

network = hk.without_apply_rng(hk.transform_with_state(resnet_fn))

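Some notes on the code above:

- hk.transform_with_state is used instead of hk.transform because ResNet50 contains BatchNorm layers, whose moving statistics are kept as state alongside the params.
- hk.without_apply_rng removes the rng argument from network.apply, since the forward pass doesn't need any randomness.
- is_training is passed through so that BatchNorm knows whether to use batch statistics and update its state.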

Then we initialize the network to retrieve both the params and the state. As mentioned above, these are random rather than pre-trained.

import jax

rng_key = jax.random.PRNGKey(seed=0)
initial_params, initial_state = network.init(
    rng_key, jnp.ones([1, 224, 224, 3]), is_training=True
)
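The split between params and state follows a simple convention: init returns (params, state), and apply takes both and returns (output, new_state). The sketch below is a toy, dependency-free illustration of that convention, not haiku's API. The plain functions `init` and `apply` are hypothetical, and a made-up running mean stands in for BatchNorm's moving statistics.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]
State = Dict[str, float]


def init() -> Tuple[Params, State]:
    """Return trainable params and non-trainable state separately."""
    return {"scale": 1.0}, {"moving_mean": 0.0}


def apply(params: Params, state: State, x: List[float], is_training: bool) -> Tuple[List[float], State]:
    """Center x and scale it; the running mean is only updated while training."""
    batch_mean = sum(x) / len(x)
    if is_training:
        decay = 0.75
        new_state = {"moving_mean": decay * state["moving_mean"] + (1 - decay) * batch_mean}
        mean = batch_mean
    else:
        new_state = state  # inference leaves the state untouched
        mean = state["moving_mean"]
    return [params["scale"] * (v - mean) for v in x], new_state


params, state = init()
out, new_state = apply(params, state, [1.0, 3.0], is_training=True)
y, _ = apply(params, new_state, [1.0, 3.0], is_training=False)  # like network.apply(..., is_training=False)
```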

Now that we have the params, we can use safejax.serialize to serialize them with safetensors as the tensor storage format. They can later be loaded back using safejax.deserialize and used for the network's inference.

from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes)

Finally, let's use those decoded_params to run inference with the network.

x = jnp.ones([1, 224, 224, 3])
y, _ = network.apply(decoded_params, initial_state, x, is_training=False)

To run the objax example script, you won't need to install anything other than safejax, as both jax and objax are installed as part of it.

In this case, we'll use ResNet50, one of the architectures defined in the objax model zoo at objax/zoo. So, first of all, let's initialize it:

from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

Once initialized, we can already access the model params. In objax, these are stored in model.vars() as a VarCollection, which is a dictionary-like class. So we can serialize them with safejax.serialize using the safetensors format instead of pickle, which is the currently recommended way (see https://objax.readthedocs.io/en/latest/advanced/io.html).

from safejax import serialize

encoded_bytes = serialize(params=model.vars())

Then we can deserialize those params using safejax.deserialize, and we'll end up with the same VarCollection back. Note that we disable unflattening with requires_unflattening=False, since a VarCollection is already stored as a flat mapping from variable names to tensors. We also set to_var_collection=True to get a VarCollection instead of a Dict[str, jnp.DeviceArray], even though a standard dict would work too.

from safejax import deserialize

decoded_params = deserialize(
    encoded_bytes, requires_unflattening=False, to_var_collection=True
)
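safetensors stores flat name-to-tensor entries, which is why nested params need to be flattened before serialization and unflattened afterwards. The stdlib-only sketch below assumes "." as the delimiter. That is an assumption for this sketch, not necessarily what safejax uses, and `flatten` and `unflatten` are hypothetical helpers. The sketch also shows why skipping unflattening matters for names that already contain the delimiter, such as the illustrative objax-style name "(Linear).w": splitting it would produce a spurious nested structure.

```python
from typing import Any, Dict


def flatten(tree: Dict[str, Any], delimiter: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Turn a nested dict into {"a.b.c": leaf} entries."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{delimiter}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, delimiter, name))
        else:
            flat[name] = value
    return flat


def unflatten(flat: Dict[str, Any], delimiter: str = ".") -> Dict[str, Any]:
    """Rebuild the nested dict by splitting every key on the delimiter."""
    tree: Dict[str, Any] = {}
    for name, value in flat.items():
        *parents, leaf = name.split(delimiter)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


nested = {"params": {"Dense_0": {"kernel": 1, "bias": 2}}}
flat = flatten(nested)       # {"params.Dense_0.kernel": 1, "params.Dense_0.bias": 2}
broken = unflatten({"(Linear).w": 3})  # an already-flat name gets split apart
```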

Now, once decoded with safejax.deserialize, we need to assign those key-value pairs back to the VarCollection of the ResNet50. We can't simply use .update, since objax redefines it (see https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311) in a way that isn't consistent with the standard dict.update (already reported at google/objax#254). So, instead, we loop over all the key-value pairs in the decoded params and assign them one by one to the VarCollection in model.vars().

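# Build the VarCollection once instead of calling model.vars() on every iteration.
model_vars = model.vars()
for key, value in decoded_params.items():
    if key not in model_vars:
        print(f"Key {key} not in model.vars()! Skipping.")
        continue
    # Assign in place, since VarCollection.update doesn't behave like dict.update.
    model_vars[key].assign(value)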
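The same loop can be extended to also report model variables that weren't present in the decoded params, which helps catch a mismatched checkpoint. This is a minimal sketch under stated assumptions. The `Var` class is a hypothetical stand-in for an objax variable holding a flat list of floats, and `assign_params` is a hypothetical helper; neither is objax's API.

```python
from typing import Dict, List, Tuple


class Var:
    """Hypothetical stand-in for an objax variable holding a flat list of floats."""

    def __init__(self, value: List[float]):
        self.value = value

    def assign(self, value: List[float]) -> None:
        # Mimic a shape check before replacing the stored value.
        if len(value) != len(self.value):
            raise ValueError(f"expected {len(self.value)} values, got {len(value)}")
        self.value = value


def assign_params(model_vars: Dict[str, Var], params: Dict[str, List[float]]) -> Tuple[List[str], List[str]]:
    """Assign matching entries and return (skipped param keys, untouched model keys)."""
    skipped = [key for key in params if key not in model_vars]
    for key, value in params.items():
        if key in model_vars:
            model_vars[key].assign(value)
    missing = [key for key in model_vars if key not in params]
    return skipped, missing


model_vars = {"w": Var([0.0, 0.0]), "b": Var([0.0])}
skipped, missing = assign_params(model_vars, {"w": [1.0, 2.0], "extra": [3.0]})
print(skipped, missing)  # ['extra'] ['b']
```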

And, finally, we can run inference with the model via its __call__ method, since model.vars() now holds the params resulting from safejax.deserialize.

from jax import numpy as jnp

x = jnp.ones((1, 3, 224, 224))
y = model(x, training=False)

Note that we set training=False, which is the standard way of running inference with a pre-trained model in objax.