
Commit

Add test for deepseek_math
meenakshiramanathan1 committed Feb 4, 2025
1 parent 8d30bab commit 1a22034
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions forge/test/models/pytorch/multimodal/deepseek/test_deepseek_math.py
@@ -0,0 +1,105 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

import forge

from test.models.utils import Framework, Source, Task, build_module_name


class Wrapper(torch.nn.Module):
    """Expose the Hugging Face model as a plain forward(input_ids) -> logits module."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_tensor):
        # max_new_tokens is a generate()-time argument, not a forward() argument,
        # so a single plain forward pass is performed here.
        return self.model(input_tensor).logits


@pytest.mark.skip(reason="Requires host DRAM during compile time")
@pytest.mark.parametrize("variant", ["deepseek-math-7b-instruct"])
def test_deepseek_inference_no_cache_cpu(variant):
    model_name = f"deepseek-ai/{variant}"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
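    # Disabling the KV cache keeps the model's forward signature to a single
    # input_ids tensor, which is exactly what the Wrapper above exposes.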
    model.generation_config.use_cache = False
    framework_model = Wrapper(model)

    # Prepare input sentence
    messages = [
        {
            "role": "user",
            "content": "what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}.",
        }
    ]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    max_new_tokens = 100
    # Greedy decoding on the CPU reference model: append the argmax token each
    # step until EOS or the token budget is reached.
    for _ in range(max_new_tokens):
        logits = framework_model(input_ids)
        # logits has shape (batch, seq_len, vocab_size); select the last position
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if next_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)

    # Generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    print(generated_text)


def download_model_and_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
    return model, tokenizer


@pytest.mark.parametrize("variant", ["deepseek-math-7b-instruct"])
def test_deepseek_inference(record_forge_property, variant):
    # Build Module Name
    module_name = build_module_name(
        framework=Framework.PYTORCH, model="deepseek", variant=variant, task=Task.QA, source=Source.HUGGINGFACE
    )

    # Record Forge Property
    record_forge_property("model_name", module_name)

    model_name = f"deepseek-ai/{variant}"

    model, tokenizer = download_model_and_tokenizer(model_name)
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.generation_config.use_cache = False
    framework_model = Wrapper(model)

    # Prepare input sentence
    messages = [
        {
            "role": "user",
            "content": "what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}.",
        }
    ]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    # Compile the wrapped model with Forge, using the prompt as the sample input
    compiled_model = forge.compile(framework_model, sample_inputs=[input_ids], module_name=module_name)
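    # Only a single token is generated below; since the compiled graph is traced
    # against the prompt's shape via sample_inputs, growing input_ids beyond
    # that shape may not be supported by the compiled model.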
    max_new_tokens = 1
    for _ in range(max_new_tokens):
        logits = compiled_model(input_ids)
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if next_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)

    # Generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    print(generated_text)
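
For local verification, the new tests can be run through pytest from Python as well as from the command line (a minimal sketch, assuming the repository's standard pytest setup; the path matches the file added in this commit):

    import pytest

    pytest.main(["forge/test/models/pytorch/multimodal/deepseek/test_deepseek_math.py", "-v"])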
