Skip to content

Commit

Permalink
Add CPU test for Nano GPT Model (#1125)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdeviTT authored Feb 10, 2025
1 parent 2c25e9d commit 8d6a4da
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions forge/test/models/pytorch/text/nanogpt/test_nanogpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from transformers import AutoModel, AutoTokenizer

import forge
from forge.verify.verify import verify

from test.models.utils import Framework, Source, Task, build_module_name


# Wrapper to get around attention mask
class Wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask):
return self.model(input_ids, None, attention_mask)


@pytest.mark.nightly
@pytest.mark.parametrize("variant", ["FinancialSupport/NanoGPT"])
def test_nanogpt_text_generation(record_forge_property, variant):

# Build Module Name
module_name = build_module_name(
framework=Framework.PYTORCH,
model="nanogpt",
variant=variant,
task=Task.TEXT_GENERATION,
source=Source.HUGGINGFACE,
)

# Record Forge Property
record_forge_property("model_name", module_name)

# Load the model
tokenizer = AutoTokenizer.from_pretrained(variant)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModel.from_pretrained(variant, ignore_mismatched_sizes=True, use_cache=False, return_dict=False)

# Input prompt
input_prompt = "The financial market showed signs of volatility"

# Tokenize input
inputs = tokenizer(
input_prompt,
return_tensors="pt",
max_length=150,
padding=True,
truncation=True,
)
input_ids = inputs["input_ids"]
attn_mask = inputs["attention_mask"]
inputs = [input_ids, attn_mask]

framework_model = Wrapper(model)

# Forge compile framework model
compiled_model = forge.compile(
framework_model,
inputs,
module_name=module_name,
)

# Model Verification
verify(inputs, framework_model, compiled_model)

0 comments on commit 8d6a4da

Please sign in to comment.