🧪 Add more unit test cases e.g. ResNet FrozenDict #3

Merged: 14 commits, Dec 24, 2022
4 changes: 2 additions & 2 deletions .github/workflows/ci-cd.yaml
Original file line number Diff line number Diff line change
@@ -36,8 +36,8 @@ jobs:
- name: check-quality
run: |
ruff src tests
black --check --diff --preview src tests
ruff src tests benchmarks examples
black --check --diff --preview src tests benchmarks examples
run-tests:
needs: check-quality
43 changes: 29 additions & 14 deletions README.md
@@ -36,7 +36,7 @@ model = SingleLayerModel(features=1)
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))

serialized = serialize(frozen_or_unfrozen_dict=params)
serialized = serialize(params=params)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
```
@@ -72,27 +72,42 @@ using `safetensors` as the tensor storage format instead of `pickle`.

## 🏋🏼 Benchmark

Benchmarks use [`hyperfine`](https://github.com/sharkdp/hyperfine) so it needs
Benchmarks no longer run with [`hyperfine`](https://github.com/sharkdp/hyperfine),
since most of the elapsed time is spent on the imports and on the model parameter
initialization rather than on the serialization itself. They have been refactored into
pure Python scripts that measure the elapsed time in seconds with `time.perf_counter`.

```bash
$ python benchmarks/resnet50.py
safejax (100 runs): 2.0974 s
flax (100 runs): 4.8734 s
```
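The timing pattern those scripts follow can be sketched with nothing but the standard library; here `serialize_fn` is a stand-in for the real serialization call (e.g. `serialize(params)`):

```python
from time import perf_counter


def time_runs(fn, runs: int = 100) -> float:
    """Return the total elapsed time in seconds for `runs` calls of `fn`."""
    start_time = perf_counter()
    for _ in range(runs):
        fn()
    return perf_counter() - start_time


# Stand-in for a real serialization call such as `serialize(params)`.
def serialize_fn() -> bytes:
    return bytes(1024)


elapsed = time_runs(serialize_fn)
print(f"stand-in (100 runs): {elapsed:0.4f} s")
```

Each benchmark script simply applies this pattern to `safejax.flax.serialize` and `flax.serialization.to_bytes` in turn.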

This means that for `ResNet50`, `safejax` is about 2.3 times faster than `flax.serialization`
when it comes to serialization, while also storing the tensors with `safetensors` instead
of `pickle`.

If we instead use [`hyperfine`](https://github.com/sharkdp/hyperfine) as mentioned above, it needs
to be installed first, and the `hatch`/`pyenv` environment needs to be activated
(or the requirements installed). Due to the overhead of the script itself, the time spent
on serialization is minimal compared to the rest, so the overall result does not reflect
the efficiency difference between the two approaches shown above.

```bash
$ hyperfine --warmup 2 "python benchmark.py benchmark_safejax" "python benchmark.py benchmark_flax"
Benchmark 1: python benchmark.py benchmark_safejax
Time (mean ± σ): 539.6 ms ± 11.9 ms [User: 1693.2 ms, System: 690.4 ms]
Range (min … max): 516.1 ms … 555.7 ms    10 runs
$ hyperfine --warmup 2 "python benchmarks/hyperfine/resnet50.py serialization_safejax" "python benchmarks/hyperfine/resnet50.py serialization_flax"
Benchmark 1: python benchmarks/hyperfine/resnet50.py serialization_safejax
Time (mean ± σ): 1.778 s ± 0.038 s [User: 3.345 s, System: 0.511 s]
Range (min … max): 1.741 s … 1.877 s    10 runs

Benchmark 2: python benchmark.py benchmark_flax
Time (mean ± σ): 543.2 ms ± 5.6 ms [User: 1659.6 ms, System: 748.9 ms]
Range (min … max): 532.0 ms … 551.5 ms    10 runs
Benchmark 2: python benchmarks/hyperfine/resnet50.py serialization_flax
Time (mean ± σ): 1.790 s ± 0.011 s [User: 3.371 s, System: 0.478 s]
Range (min … max): 1.771 s … 1.810 s    10 runs

Summary
'python benchmark.py benchmark_safejax' ran
1.01 ± 0.02 times faster than 'python benchmark.py benchmark_flax'
'python benchmarks/hyperfine/resnet50.py serialization_safejax' ran
1.01 ± 0.02 times faster than 'python benchmarks/hyperfine/resnet50.py serialization_flax'
```

As we can see, the difference is barely noticeable here, since the benchmark uses a
two-tensor dictionary, which is fast to serialize with any method. The main difference
lies in using `safetensors` for tensor storage instead of `pickle`.

More detailed and complex benchmarks will be prepared soon!
27 changes: 27 additions & 0 deletions benchmarks/hyperfine/resnet50.py
@@ -0,0 +1,27 @@
import sys

import jax
from flax.serialization import to_bytes
from flaxmodels.resnet import ResNet50
from jax import numpy as jnp

from safejax.flax import serialize

resnet50 = ResNet50()
params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3)))


def serialization_safejax():
_ = serialize(params)


def serialization_flax():
_ = to_bytes(params)


if __name__ == "__main__":
if len(sys.argv) < 2:
raise ValueError("Please provide a function name to run as an argument")
if sys.argv[1] not in globals():
raise ValueError(f"Function {sys.argv[1]} not found")
globals()[sys.argv[1]]()
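The `globals()[sys.argv[1]]()` dispatch at the bottom is what lets `hyperfine` select a benchmark function by name from the command line. A minimal standalone sketch of the same pattern, with hypothetical `task_a`/`task_b` functions in place of the benchmark functions:

```python
import sys


def task_a() -> str:
    return "ran task_a"


def task_b() -> str:
    return "ran task_b"


def dispatch(name: str):
    # Mirror the scripts' lookup: resolve the function by name, fail loudly otherwise.
    if name not in globals() or not callable(globals()[name]):
        raise ValueError(f"Function {name} not found")
    return globals()[name]()


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("usage: python dispatch_sketch.py <function_name>")
    else:
        print(dispatch(sys.argv[1]))
```

This keeps a single script per model while letting each timed command pick exactly one function to run, so `hyperfine` can compare them as separate benchmarks.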
8 changes: 6 additions & 2 deletions benchmark.py → benchmarks/hyperfine/single_layer.py
@@ -23,13 +23,17 @@ def __call__(self, x):
params = model.init(rng, jnp.ones((1, 1)))


def benchmark_safejax():
def serialization_safejax():
_ = serialize(params)


def benchmark_flax():
def serialization_flax():
_ = to_bytes(params)


if __name__ == "__main__":
if len(sys.argv) < 2:
raise ValueError("Please provide a function name to run as an argument")
if sys.argv[1] not in globals():
raise ValueError(f"Function {sys.argv[1]} not found")
globals()[sys.argv[1]]()
24 changes: 24 additions & 0 deletions benchmarks/resnet50.py
@@ -0,0 +1,24 @@
from time import perf_counter

import jax
from flax.serialization import to_bytes
from flaxmodels.resnet import ResNet50
from jax import numpy as jnp

from safejax.flax import serialize

resnet50 = ResNet50()
params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3)))


start_time = perf_counter()
for _ in range(100):
serialize(params)
end_time = perf_counter()
print(f"safejax (100 runs): {end_time - start_time:0.4f} s")

start_time = perf_counter()
for _ in range(100):
to_bytes(params)
end_time = perf_counter()
print(f"flax (100 runs): {end_time - start_time:0.4f} s")
36 changes: 36 additions & 0 deletions benchmarks/single_layer.py
@@ -0,0 +1,36 @@
from time import perf_counter

import jax
from flax import linen as nn
from flax.serialization import to_bytes
from jax import numpy as jnp

from safejax.flax import serialize


class SingleLayerModel(nn.Module):
features: int

@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.features)(x)
return x


model = SingleLayerModel(features=1)

rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))


start_time = perf_counter()
for _ in range(100):
serialize(params)
end_time = perf_counter()
print(f"safejax (100 runs): {end_time - start_time:0.4f} s")

start_time = perf_counter()
for _ in range(100):
to_bytes(params)
end_time = perf_counter()
print(f"flax (100 runs): {end_time - start_time:0.4f} s")
1 change: 0 additions & 1 deletion examples/serialization_with_flax.py
@@ -1,6 +1,5 @@
import jax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.serialization import from_bytes, to_bytes
from jax import numpy as jnp

2 changes: 1 addition & 1 deletion examples/serialization_with_safejax.py
@@ -20,7 +20,7 @@ def __call__(self, x):
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))

serialized = serialize(frozen_or_unfrozen_dict=params)
serialized = serialize(params=params)
assert isinstance(serialized, bytes)
assert len(serialized) > 0

13 changes: 9 additions & 4 deletions pyproject.toml
@@ -36,6 +36,9 @@ Source = "https://github.com/alvarobartt/safejax"
[tool.hatch.version]
path = "src/safejax/__init__.py"

[tool.hatch.metadata]
allow-direct-references = true

[project.optional-dependencies]
quality = [
"black~=22.10.0",
@@ -44,6 +47,8 @@ quality = [
]
tests = [
"pytest~=7.1.2",
"pytest-lazy-fixture~=0.6.3",
"flaxmodels @ git+https://github.com/matthias-wright/flaxmodels.git",
]

[tool.hatch.envs.quality]
@@ -53,12 +58,12 @@ features = [

[tool.hatch.envs.quality.scripts]
check = [
"ruff src tests",
"black --check --diff --preview src tests",
"ruff src tests benchmarks examples",
"black --check --diff --preview src tests benchmarks examples",
]
format = [
"ruff --fix src tests",
"black --preview src tests",
"ruff --fix src tests benchmarks examples",
"black --preview src tests benchmarks examples",
"check",
]

2 changes: 1 addition & 1 deletion src/safejax/__init__.py
@@ -1,4 +1,4 @@
@@ -1,4 +1,4 @@
"""`safejax`: Serialize JAX/Flax models with `safetensors`"""

__author__ = "Alvaro Bartolome <alvarobartt@yahoo.com>"
__version__ = "0.1.0"
__version__ = "0.1.1"
14 changes: 7 additions & 7 deletions src/safejax/flax.py
@@ -13,7 +13,7 @@


def flatten_dict(
frozen_or_unfrozen_dict: Union[Dict[str, Any], FrozenDict],
params: Union[Dict[str, Any], FrozenDict],
key_prefix: Union[str, None] = None,
) -> Union[Dict[str, jnp.DeviceArray], Dict[str, np.ndarray]]:
"""
@@ -27,25 +27,25 @@ def flatten_dict(
Reference at https://gist.github.com/Narsil/d5b0d747e5c8c299eb6d82709e480e3d
Args:
frozen_or_unfrozen_dict: A `FrozenDict` or a `Dict` containing the model parameters.
params: A `FrozenDict` or a `Dict` containing the model parameters.
key_prefix: A prefix to prepend to the keys of the flattened dictionary.
Returns:
A flattened dictionary containing the model parameters.
"""
weights = {}
for key, value in frozen_or_unfrozen_dict.items():
for key, value in params.items():
key = f"{key_prefix}.{key}" if key_prefix else key
if isinstance(value, jnp.DeviceArray) or isinstance(value, np.ndarray):
weights[key] = value
continue
if isinstance(value, FrozenDict) or isinstance(value, Dict):
weights.update(flatten_dict(frozen_or_unfrozen_dict=value, key_prefix=key))
weights.update(flatten_dict(params=value, key_prefix=key))
return weights
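To illustrate what `flatten_dict` produces, here is a minimal self-contained sketch of the same key-flattening idea over plain nested `dict`s (no `FrozenDict` or JAX arrays; plain floats stand in for the tensor leaves):

```python
from typing import Any, Dict, Union


def flatten(
    params: Dict[str, Any], key_prefix: Union[str, None] = None
) -> Dict[str, Any]:
    """Flatten a nested dict into dot-separated keys, mirroring `flatten_dict`."""
    weights: Dict[str, Any] = {}
    for key, value in params.items():
        key = f"{key_prefix}.{key}" if key_prefix else key
        if isinstance(value, dict):
            # Recurse into nested dicts, prepending the current key as prefix.
            weights.update(flatten(params=value, key_prefix=key))
        else:
            weights[key] = value
    return weights


nested = {"params": {"Dense_0": {"kernel": 0.5, "bias": 0.1}}}
print(flatten(nested))
# {'params.Dense_0.kernel': 0.5, 'params.Dense_0.bias': 0.1}
```

The real `flatten_dict` does the same walk, but keeps `jnp.DeviceArray`/`np.ndarray` values as leaves so the result can be handed directly to `safetensors`.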


def serialize(
frozen_or_unfrozen_dict: Union[Dict[str, Any], FrozenDict],
params: Union[Dict[str, Any], FrozenDict],
filename: Union[PathLike, None] = None,
) -> Union[bytes, PathLike]:
"""
@@ -55,13 +55,13 @@ def serialize(
otherwise the model is saved to the provided `filename` and the `filename` is returned.
Args:
frozen_or_unfrozen_dict: A `FrozenDict` or a `Dict` containing the model parameters.
params: A `FrozenDict` or a `Dict` containing the model parameters.
filename: The path to the file where the model will be saved.
Returns:
The serialized model as a `bytes` object or the path to the file where the model was saved.
"""
flattened_dict = flatten_dict(frozen_or_unfrozen_dict=frozen_or_unfrozen_dict)
flattened_dict = flatten_dict(params=params)
if not filename:
return save(tensors=flattened_dict)
else:
23 changes: 18 additions & 5 deletions tests/conftest.py
@@ -5,9 +5,10 @@
import pytest
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flaxmodels.resnet import ResNet50


class SingleLayerModel(nn.Module):
class SingleLayer(nn.Module):
features: int

@nn.compact
@@ -17,15 +18,27 @@ def __call__(self, x):


@pytest.fixture
def single_layer_model() -> nn.Module:
return SingleLayerModel(features=1)
def single_layer() -> nn.Module:
return SingleLayer(features=1)


@pytest.fixture
def single_layer_model_params(single_layer_model: nn.Module) -> FrozenDict:
def single_layer_params(single_layer: nn.Module) -> FrozenDict:
# https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#fixtures-can-request-other-fixtures
rng = jax.random.PRNGKey(0)
params = single_layer_model.init(rng, jnp.ones((1, 1)))
params = single_layer.init(rng, jnp.ones((1, 1)))
return params


@pytest.fixture
def resnet50() -> nn.Module:
return ResNet50()


@pytest.fixture
def resnet50_params(resnet50: nn.Module) -> FrozenDict:
rng = jax.random.PRNGKey(0)
params = resnet50.init(rng, jnp.ones((1, 224, 224, 3)))
return params

