Skip to content

Commit

Permalink
Merge branch 'main' into umales/mnist_mse
Browse files Browse the repository at this point in the history
  • Loading branch information
umalesTT authored Jan 31, 2025
2 parents d4386ac + 1b68a3c commit f3236bb
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 1 deletion.
Empty file.
22 changes: 22 additions & 0 deletions tests/jax/models/mnist/mlp/model_implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from flax import linen as nn


class MNISTMLPModel(nn.Module):
hidden_sizes: tuple[int]

@nn.compact
def __call__(self, x):
x = x.reshape((x.shape[0], -1))

for h in self.hidden_sizes:
x = nn.Dense(features=h)(x)
x = nn.relu(x)

x = nn.Dense(features=10)(x)
x = nn.softmax(x)

return x
89 changes: 89 additions & 0 deletions tests/jax/models/mnist/mlp/test_mnist_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Sequence

import jax
import pytest
from flax import linen as nn
from infra import ComparisonConfig, ModelTester, RunMode

from .model_implementation import MNISTMLPModel


class MNISTMLPTester(ModelTester):
"""Tester for MNIST MLP model."""

def __init__(
self,
hidden_sizes: Sequence[int],
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
self._hidden_sizes = hidden_sizes
super().__init__(comparison_config, run_mode)

# @override
def _get_model(self) -> nn.Module:
return MNISTMLPModel(self._hidden_sizes)

# @override
def _get_forward_method_name(self) -> str:
return "apply"

# @override
def _get_input_activations(self) -> Sequence[jax.Array]:
key = jax.random.PRNGKey(37)
img = jax.random.normal(key, (4, 28, 28, 1)) # B, H, W, C
# Channels is 1 as MNIST is in grayscale.
return img

# @override
def _get_forward_method_args(self):
inp = self._get_input_activations()

parameters = self._model.init(jax.random.PRNGKey(42), inp)

return [parameters, inp]


# ----- Fixtures -----


@pytest.fixture
def inference_tester(request) -> MNISTMLPTester:
return MNISTMLPTester(request.param)


@pytest.fixture
def training_tester(request) -> MNISTMLPTester:
return MNISTMLPTester(request.param, run_mode=RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.parametrize(
"inference_tester",
[
(128,),
(128, 128),
(192, 128),
(512, 512),
(128, 128, 128),
(256, 128, 64),
],
indirect=True,
)
def test_mnist_inference(
inference_tester: MNISTMLPTester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_mnist_training(
training_tester: MNISTMLPTester,
):
training_tester.test()
Empty file.
3 changes: 2 additions & 1 deletion tests/jax/models/squeezebert/test_squeezebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from flax import linen as nn
from huggingface_hub import hf_hub_download
from infra import ModelTester, RunMode
from model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM
from transformers import AutoTokenizer

from .model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM

MODEL_PATH = "squeezebert/squeezebert-uncased"

# ----- Tester -----
Expand Down

0 comments on commit f3236bb

Please sign in to comment.