From af22c1bc4ab2dc79b3b0373171083a2e5338f4f0 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Tue, 27 Aug 2024 03:06:33 -0400 Subject: [PATCH] Add a simple test in case for Flax modules --- tests/test_flax_module.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_flax_module.py diff --git a/tests/test_flax_module.py b/tests/test_flax_module.py new file mode 100644 index 00000000..070f7e22 --- /dev/null +++ b/tests/test_flax_module.py @@ -0,0 +1,32 @@ +import jax.numpy as jnp +import pytest +from flax import nnx +from jax import Array, vmap + +from tjax.dataclasses import DataClassModule + +try: + from flax import nnx +except ImportError: + pytest.skip("Skipping NNX graph test", allow_module_level=True) + + +@pytest.mark.skip +def test_dataclass_module() -> None: + class SomeModule(nnx.Module): + def __init__(self, epsilon: Array): + super().__init__() + self.epsilon = epsilon + + class SomeDataclassModule(DataClassModule): + def __init__(self, rngs: nnx.Rngs) -> None: + super().__init__(rngs=rngs) + self.sm = SomeModule(jnp.zeros(1)) + + def f(m: SomeDataclassModule, x: Array) -> None: + pass + + rngs = nnx.Rngs() + module = SomeDataclassModule(rngs) + z = jnp.zeros(10) + vmap(f, in_axes=(None, 0))(module, z)