Skip to content

Commit

Permalink
Add tests for BERT (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgligorijevicTT authored Jan 31, 2025
1 parent d960be9 commit 3f4a9e7
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 0 deletions.
Empty file.
Empty file.
42 changes: 42 additions & 0 deletions tests/jax/models/bert/base/test_bert_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBertForMaskedLMTester

MODEL_PATH = "google-bert/bert-base-uncased"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBertForMaskedLMTester:
    """Provide a tester for the BERT base checkpoint named by MODEL_PATH.

    No run mode is passed, so the tester's own default applies.
    """
    tester = FlaxBertForMaskedLMTester(MODEL_PATH)
    return tester


@pytest.fixture
def training_tester() -> FlaxBertForMaskedLMTester:
    """Provide a tester for the BERT base checkpoint in training mode.

    The run mode is passed by keyword on purpose: the tester's second
    positional parameter is ``comparison_config``, so a positional
    ``RunMode.TRAINING`` would be taken as the comparison config and the
    tester would stay in its default run mode.
    """
    return FlaxBertForMaskedLMTester(MODEL_PATH, run_mode=RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(
    reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_bert_base_inference(inference_tester: FlaxBertForMaskedLMTester) -> None:
    """Invoke the inference tester's ``test()`` entry point for BERT base.

    Marked xfail: the linked issue currently makes this run fail.
    """
    inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bert_base_training(training_tester: FlaxBertForMaskedLMTester) -> None:
    """Invoke the training tester's ``test()`` entry point for BERT base.

    Skipped unconditionally for now; see the skip reason on the decorator.
    """
    training_tester.test()
Empty file.
42 changes: 42 additions & 0 deletions tests/jax/models/bert/large/test_bert_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import FlaxBertForMaskedLMTester

MODEL_PATH = "google-bert/bert-large-uncased"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> FlaxBertForMaskedLMTester:
    """Provide a tester for the BERT large checkpoint named by MODEL_PATH.

    No run mode is passed, so the tester's own default applies.
    """
    tester = FlaxBertForMaskedLMTester(MODEL_PATH)
    return tester


@pytest.fixture
def training_tester() -> FlaxBertForMaskedLMTester:
    """Provide a tester for the BERT large checkpoint in training mode.

    The run mode is passed by keyword on purpose: the tester's second
    positional parameter is ``comparison_config``, so a positional
    ``RunMode.TRAINING`` would be taken as the comparison config and the
    tester would stay in its default run mode.
    """
    return FlaxBertForMaskedLMTester(MODEL_PATH, run_mode=RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.xfail(
    reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
)
def test_flax_bert_large_inference(inference_tester: FlaxBertForMaskedLMTester) -> None:
    """Invoke the inference tester's ``test()`` entry point for BERT large.

    Marked xfail: the linked issue currently makes this run fail.
    """
    inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bert_large_training(training_tester: FlaxBertForMaskedLMTester) -> None:
    """Invoke the training tester's ``test()`` entry point for BERT large.

    Skipped unconditionally for now; see the skip reason on the decorator.
    """
    training_tester.test()
41 changes: 41 additions & 0 deletions tests/jax/models/bert/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Optional, Sequence

import jax
from flax import linen as nn
from infra import ComparisonConfig, ModelTester, RunMode
from transformers import AutoTokenizer, FlaxBertForMaskedLM


class FlaxBertForMaskedLMTester(ModelTester):
    """Tester for BERT model variants on masked language modeling task.

    Both the model and the tokenizer are loaded with ``from_pretrained`` using
    the ``model_name`` given at construction.
    """

    def __init__(
        self,
        model_name: str,
        comparison_config: Optional[ComparisonConfig] = None,
        run_mode: RunMode = RunMode.INFERENCE,
    ) -> None:
        """Store the model name and initialize the base tester.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            comparison_config: Comparison settings forwarded to the base
                class. When omitted, a fresh ``ComparisonConfig()`` is created
                for this instance instead of sharing one default object
                across every tester.
            run_mode: Run mode forwarded to the base class.
        """
        # Kept ahead of the base initializer, as in the original ordering.
        # NOTE(review): presumably the base initializer invokes the hooks
        # below, which read this attribute - confirm against ModelTester.
        self._model_name = model_name
        if comparison_config is None:
            comparison_config = ComparisonConfig()
        super().__init__(comparison_config, run_mode)

    # @override
    def _get_model(self) -> nn.Module:
        """Load the pretrained Flax BERT masked-LM model for the stored name."""
        return FlaxBertForMaskedLM.from_pretrained(self._model_name)

    # @override
    def _get_input_activations(self) -> Sequence[jax.Array]:
        """Tokenize a fixed masked prompt and return its ``input_ids``."""
        tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        inputs = tokenizer("Hello [MASK]", return_tensors="np")
        return inputs["input_ids"]

    # @override
    def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
        """Return the keyword arguments for the model's forward call.

        The mapping holds the model's ``params`` and the tokenized
        ``input_ids``.
        """
        # Internal invariant: the loaded model is expected to expose `params`.
        assert hasattr(self._model, "params")
        return {
            "params": self._model.params,
            "input_ids": self._get_input_activations(),
        }

0 comments on commit 3f4a9e7

Please sign in to comment.