Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added NNX/Linen API consistency test for Linear/Dense layer #3509

Merged
merged 2 commits into from
Nov 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions flax/experimental/nnx/tests/nn/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
from absl.testing import parameterized
from jax.lax import Precision
from numpy.testing import assert_array_equal

from flax import linen
from flax.experimental import nnx


Expand All @@ -35,3 +40,30 @@ def test_basic_multi_features(self):
assert module.kernel.shape == (2, 3, 4)
assert module.bias is not None
assert module.bias.shape == (3, 4)


class TestLinenConsistency(parameterized.TestCase):

@parameterized.product(
use_bias = [True, False],
dtype = [jnp.float32, jnp.float16],
param_dtype = [jnp.float32, jnp.float16],
precision = [Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST],
)
def test_nnx_linen_equivalence(self, **kwargs):
key = jax.random.key(42)
rngs = nnx.Rngs(42)
IN_FEATURES = 32
OUT_FEATURES = 64

x = jax.numpy.ones((1, IN_FEATURES))
model_nnx = nnx.Linear.create_abstract(IN_FEATURES, OUT_FEATURES, **kwargs, rngs=rngs)
model = linen.Dense(OUT_FEATURES, **kwargs)
variables = model.init(key, x)
model_nnx.kernel = variables['params']['kernel']
if kwargs["use_bias"]:
model_nnx.bias = variables['params']['bias']

out_nnx = model_nnx(x)
out = model.apply(variables, x)
assert_array_equal(out, out_nnx)