From 3f4a9e7e4dd714ea22cbb80f0ad7292593605a09 Mon Sep 17 00:00:00 2001 From: Stefan Gligorijevic <189116645+sgligorijevicTT@users.noreply.github.com> Date: Fri, 31 Jan 2025 19:00:05 +0100 Subject: [PATCH] Add tests for BERT (#219) --- tests/jax/models/bert/__init__.py | 0 tests/jax/models/bert/base/__init__.py | 0 tests/jax/models/bert/base/test_bert_base.py | 42 +++++++++++++++++++ tests/jax/models/bert/large/__init__.py | 0 .../jax/models/bert/large/test_bert_large.py | 42 +++++++++++++++++++ tests/jax/models/bert/tester.py | 41 ++++++++++++++++++ 6 files changed, 125 insertions(+) create mode 100644 tests/jax/models/bert/__init__.py create mode 100644 tests/jax/models/bert/base/__init__.py create mode 100644 tests/jax/models/bert/base/test_bert_base.py create mode 100644 tests/jax/models/bert/large/__init__.py create mode 100644 tests/jax/models/bert/large/test_bert_large.py create mode 100644 tests/jax/models/bert/tester.py diff --git a/tests/jax/models/bert/__init__.py b/tests/jax/models/bert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/base/__init__.py b/tests/jax/models/bert/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/base/test_bert_base.py b/tests/jax/models/bert/base/test_bert_base.py new file mode 100644 index 00000000..d4552f85 --- /dev/null +++ b/tests/jax/models/bert/base/test_bert_base.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxBertForMaskedLMTester + +MODEL_PATH = "google-bert/bert-base-uncased" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) +def test_flax_bert_base_inference( + inference_tester: FlaxBertForMaskedLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bert_base_training( + training_tester: FlaxBertForMaskedLMTester, +): + training_tester.test() diff --git a/tests/jax/models/bert/large/__init__.py b/tests/jax/models/bert/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bert/large/test_bert_large.py b/tests/jax/models/bert/large/test_bert_large.py new file mode 100644 index 00000000..d7bf1fc3 --- /dev/null +++ b/tests/jax/models/bert/large/test_bert_large.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxBertForMaskedLMTester + +MODEL_PATH = "google-bert/bert-large-uncased" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBertForMaskedLMTester: + return FlaxBertForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) +def test_flax_bert_large_inference( + inference_tester: FlaxBertForMaskedLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bert_large_training( + training_tester: FlaxBertForMaskedLMTester, +): + training_tester.test() diff --git a/tests/jax/models/bert/tester.py b/tests/jax/models/bert/tester.py new file mode 100644 index 00000000..e96f83c5 --- /dev/null +++ b/tests/jax/models/bert/tester.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxBertForMaskedLM + + +class FlaxBertForMaskedLMTester(ModelTester): + """Tester for BERT model variants on masked language modeling task.""" + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxBertForMaskedLM.from_pretrained(self._model_name) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello [MASK]", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + }