Skip to content

Commit

Permalink
Rename all the NNX tests to internal naming & build conventions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638706486
  • Loading branch information
IvyZX authored and Flax Authors committed May 30, 2024
1 parent 6e63497 commit 3f47c1d
Show file tree
Hide file tree
Showing 23 changed files with 108 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

import dataclasses

from absl.testing import absltest
import jax
import jax.numpy as jnp

from flax import nnx
from flax.nnx import compat


class TestCompatModule:
class TestCompatModule(absltest.TestCase):
def test_compact_basic(self):
class Linear(compat.Module):
dout: int
Expand Down Expand Up @@ -131,4 +132,7 @@ def __call__(self, x):
assert y.shape == (1, 5)

assert hasattr(bar, 'foo')
assert isinstance(bar.foo, Foo)
assert isinstance(bar.foo, Foo)

if __name__ == '__main__':
absltest.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@


from flax import nnx
from absl.testing import absltest


class TestContainers:
class TestContainers(absltest.TestCase):
def test_unbox(self):
x = nnx.Param(
1,
Expand Down Expand Up @@ -58,3 +59,7 @@ def __init__(self) -> None:

assert module.x.value == 12
assert vars(module)['x'].raw_value == 12


if __name__ == '__main__':
absltest.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import jax.numpy as jnp
import optax

from numpy.testing import assert_array_equal
from absl.testing import absltest
import numpy as np

from flax import linen
from flax import nnx
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_nnx_linen_sequential_equivalence(self):
).value
out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)

variables = model.init(key2, x)
for layer_index in range(2):
Expand All @@ -100,4 +101,9 @@ def test_nnx_linen_sequential_equivalence(self):
][f'layers_{layer_index}'][param]
out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)


if __name__ == '__main__':
absltest.main()

5 changes: 5 additions & 0 deletions flax/nnx/tests/test_ids.py → flax/nnx/tests/ids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy

from absl.testing import absltest
from flax.nnx.nnx import ids


Expand All @@ -28,3 +29,7 @@ def test_hashable(self):
id1dc = copy.deepcopy(id1)
assert hash(id1) != hash(id1c)
assert hash(id1) != hash(id1dc)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import typing as tp

from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -23,7 +24,7 @@
A = tp.TypeVar('A')


class TestIntegration:
class TestIntegration(absltest.TestCase):
def test_shared_modules(self):
class Block(nnx.Module):
def __init__(self, linear: nnx.Linear, *, rngs):
Expand Down Expand Up @@ -257,3 +258,7 @@ def __call__(self, x):
intermediates, state = state.split(nnx.Intermediate, ...)

assert 'y' in intermediates


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from flax import nnx

from absl.testing import absltest
from absl.testing import parameterized


Expand Down Expand Up @@ -63,4 +64,8 @@ def test_multimetric(self):
metrics.reset()
values = metrics.compute()
self.assertTrue(jnp.isnan(values['accuracy']))
self.assertTrue(jnp.isnan(values['loss']))
self.assertTrue(jnp.isnan(values['loss']))


if __name__ == '__main__':
absltest.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from flax import nnx
from flax.typing import Dtype, PrecisionLike

from numpy.testing import assert_array_equal
import numpy as np

import typing as tp
from absl.testing import parameterized
from absl.testing import absltest


class TestMultiHeadAttention:
class TestMultiHeadAttention(absltest.TestCase):
def test_basic(self):
module = nnx.MultiHeadAttention(
num_heads=2,
Expand Down Expand Up @@ -167,4 +168,8 @@ def test_nnx_attention_equivalence(

out_nnx = model_nnx(x)
out, cache = model.apply(variables, x, mutable=['cache'])
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import typing as tp

import jax
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from jax.lax import Precision
from numpy.testing import assert_array_equal
import numpy as np

from flax import linen
from flax import nnx
Expand Down Expand Up @@ -102,7 +103,7 @@ def test_nnx_linen_conv_equivalence(

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)

@parameterized.product(
strides=[None, (2, 3)],
Expand Down Expand Up @@ -166,4 +167,8 @@ def test_nnx_linen_convtranspose_equivalence(

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import typing as tp

import jax
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from numpy.testing import assert_array_equal
import numpy as np

from flax import linen
from flax import nnx
Expand Down Expand Up @@ -62,11 +63,15 @@ def test_nnx_linen_equivalence(

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)

x = jax.numpy.ones((10,), dtype=input_dtype) * 10
out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert isinstance(out, jax.Array)
assert_array_equal(out, out_nnx)
assert_array_equal(jax.numpy.isnan(out).all(), jax.numpy.array([True]))
np.testing.assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(jax.numpy.isnan(out).all(), jax.numpy.array([True]))


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

import jax
import jax.numpy as jnp
from absl.testing import absltest
from absl.testing import parameterized
from jax.lax import Precision
from numpy.testing import assert_array_equal
import numpy as np

from flax import linen
from flax import nnx
Expand Down Expand Up @@ -91,7 +92,7 @@ def test_nnx_linear_equivalence(

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)

@parameterized.product(
einsum_str=['defab,bcef->adefc', 'd...ab,bc...->ad...c'],
Expand Down Expand Up @@ -139,7 +140,7 @@ def test_nnx_einsum_equivalence(
variables['params']['bias'] = model_nnx.bias.value
out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)

variables = model.init(key, x)
model_nnx.kernel.value = variables['params']['kernel']
Expand All @@ -148,4 +149,8 @@ def test_nnx_einsum_equivalence(
model_nnx.bias.value = variables['params']['bias']
out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)
np.testing.assert_array_equal(out, out_nnx)


if __name__ == '__main__':
absltest.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import jax
import jax.numpy as jnp
from absl.testing import absltest
from absl.testing import parameterized
from numpy.testing import assert_array_equal
import numpy as np

from flax import linen
from flax import nnx
Expand All @@ -29,14 +30,14 @@ class TestLinenConsistency(parameterized.TestCase):
dtype=[jnp.float32, jnp.float16],
param_dtype=[jnp.float32, jnp.float16],
use_fast_variance=[True, False],
mask=[None, jnp.array([True, False, True, False, True])],
mask=[None, np.array([True, False, True, False, True])],
)
def test_nnx_linen_batchnorm_equivalence(
self,
dtype: tp.Optional[Dtype],
param_dtype: Dtype,
use_fast_variance: bool,
mask: tp.Optional[jax.Array],
mask: tp.Optional[np.ndarray],
):
class NNXModel(nnx.Module):
def __init__(self, dtype, param_dtype, use_fast_variance, rngs):
Expand Down Expand Up @@ -99,20 +100,20 @@ def __call__(self, x, *, mask=None):
nnx_model.linear.bias.value = variables['params']['linear']['bias']

nnx_out = nnx_model(x, mask=mask)
assert_array_equal(linen_out, nnx_out)
np.testing.assert_array_equal(linen_out, nnx_out)

@parameterized.product(
dtype=[jnp.float32, jnp.float16],
param_dtype=[jnp.float32, jnp.float16],
use_fast_variance=[True, False],
mask=[None, jnp.array([True, False, True, False, True])],
mask=[None, np.array([True, False, True, False, True])],
)
def test_nnx_linen_layernorm_equivalence(
self,
dtype: tp.Optional[Dtype],
param_dtype: Dtype,
use_fast_variance: bool,
mask: tp.Optional[jax.Array],
mask: tp.Optional[np.ndarray],
):
class NNXModel(nnx.Module):
def __init__(self, dtype, param_dtype, use_fast_variance, rngs):
Expand Down Expand Up @@ -171,20 +172,20 @@ def __call__(self, x, *, mask=None):
nnx_model.linear.bias.value = variables['params']['linear']['bias']

nnx_out = nnx_model(x, mask=mask)
assert_array_equal(linen_out, nnx_out)
np.testing.assert_array_equal(linen_out, nnx_out)

@parameterized.product(
dtype=[jnp.float32, jnp.float16],
param_dtype=[jnp.float32, jnp.float16],
use_fast_variance=[True, False],
mask=[None, jnp.array([True, False, True, False, True])],
mask=[None, np.array([True, False, True, False, True])],
)
def test_nnx_linen_rmsnorm_equivalence(
self,
dtype: tp.Optional[Dtype],
param_dtype: Dtype,
use_fast_variance: bool,
mask: tp.Optional[jax.Array],
mask: tp.Optional[np.ndarray],
):
class NNXModel(nnx.Module):
def __init__(self, dtype, param_dtype, use_fast_variance, rngs):
Expand Down Expand Up @@ -243,4 +244,8 @@ def __call__(self, x, *, mask=None):
nnx_model.linear.bias.value = variables['params']['linear']['bias']

nnx_out = nnx_model(x, mask=mask)
assert_array_equal(linen_out, nnx_out)
np.testing.assert_array_equal(linen_out, nnx_out)


if __name__ == '__main__':
absltest.main()
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flax import nnx

from absl.testing import absltest
from absl.testing import parameterized


Expand Down Expand Up @@ -117,4 +118,8 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch]
state.update(grads=grads, values=loss_fn(state.model))
initial_loss = state.metrics.compute()
state.update(grads=grads, values=loss_fn(state.model))
self.assertTrue(state.metrics.compute() < initial_loss)
self.assertTrue(state.metrics.compute() < initial_loss)


if __name__ == '__main__':
absltest.main()
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 3f47c1d

Please sign in to comment.