Skip to content

Commit

Permalink
Add a simple test in case for Flax modules
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 27, 2024
1 parent 55dac85 commit af22c1b
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_flax_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import jax.numpy as jnp
import pytest
from flax import nnx
from jax import Array, vmap

from tjax.dataclasses import DataClassModule

try:
from flax import nnx
except ImportError:
pytest.skip("Skipping NNX graph test", allow_module_level=True)


@pytest.mark.skip
def test_dataclass_module() -> None:
class SomeModule(nnx.Module):
def __init__(self, epsilon: Array):
super().__init__()
self.epsilon = epsilon

class SomeDataclassModule(DataClassModule):
def __init__(self, rngs: nnx.Rngs) -> None:
super().__init__(rngs=rngs)
self.sm = SomeModule(jnp.zeros(1))

def f(m: SomeDataclassModule, x: Array) -> None:
pass

rngs = nnx.Rngs()
module = SomeDataclassModule(rngs)
z = jnp.zeros(10)
vmap(f, in_axes=(None, 0))(module, z)

0 comments on commit af22c1b

Please sign in to comment.