Skip to content

Commit

Permalink
Merge pull request #3509 from Micky774:nnx_linear_api_tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586471437
  • Loading branch information
Flax Authors committed Nov 29, 2023
2 parents 4352879 + 3024a14 commit 68333d4
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions flax/experimental/nnx/tests/nn/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
from absl.testing import parameterized
from jax.lax import Precision
from numpy.testing import assert_array_equal

from flax import linen
from flax.experimental import nnx


Expand All @@ -35,3 +40,30 @@ def test_basic_multi_features(self):
assert module.kernel.shape == (2, 3, 4)
assert module.bias is not None
assert module.bias.shape == (3, 4)


class TestLinenConsistency(parameterized.TestCase):

@parameterized.product(
use_bias = [True, False],
dtype = [jnp.float32, jnp.float16],
param_dtype = [jnp.float32, jnp.float16],
precision = [Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST],
)
def test_nnx_linen_equivalence(self, **kwargs):
key = jax.random.key(42)
rngs = nnx.Rngs(42)
IN_FEATURES = 32
OUT_FEATURES = 64

x = jax.numpy.ones((1, IN_FEATURES))
model_nnx = nnx.Linear.create_abstract(IN_FEATURES, OUT_FEATURES, **kwargs, rngs=rngs)
model = linen.Dense(OUT_FEATURES, **kwargs)
variables = model.init(key, x)
model_nnx.kernel = variables['params']['kernel']
if kwargs["use_bias"]:
model_nnx.bias = variables['params']['bias']

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)

0 comments on commit 68333d4

Please sign in to comment.