
Add test for deepseek_math #1148

Open

meenakshiramanathan1 wants to merge 1 commit into main from mramanathan/deepseek_math

Conversation

meenakshiramanathan1 (Contributor) commented:

No description provided.


Copy link

github-actions bot commented Feb 3, 2025

TestsPassed ✅Skipped ⚠️Failed
TT-Forge-FE Tests505 ran446 passed59 skipped0 failed
TestResult
No test annotations available


Copy link

github-actions bot commented Feb 3, 2025

TestsPassed ✅Skipped ⚠️Failed
TT-Forge-FE Tests564 ran485 passed79 skipped0 failed
TestResult
No test annotations available


meenakshiramanathan1 force-pushed the mramanathan/deepseek_math branch 2 times, most recently from 1a22034 to 4af6e34 on February 4, 2025 07:04
meenakshiramanathan1 marked this pull request as ready for review on February 4, 2025 07:05
github-actions bot commented Feb 4, 2025

TT-Forge-FE Tests: 510 ran, 451 passed, 59 skipped, 0 failed
TestResult: no test annotations available


github-actions bot commented Feb 4, 2025

TT-Forge-FE Tests: 568 ran, 489 passed, 79 skipped, 0 failed
TestResult: no test annotations available



Comment on lines +13 to +43
def generation(max_new_tokens, compiled_model, input_ids, tokenizer):
    for i in range(max_new_tokens):
        logits = compiled_model(input_ids)
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if next_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


def download_model_and_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cpu")
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    model.generation_config.use_cache = False

    # Prepare input sentence
    messages = [
        {
            "role": "user",
            "content": "what is the integral of x^2 from 0 to 2?\nPlease reason step by step, and put your final answer within \\boxed{}.",
        }
    ]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    return model, tokenizer, input_ids
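
For context, a minimal sketch of how these helpers might be wired together in the test; the model name and the forge.compile call are assumptions based on typical TT-Forge-FE model tests, not lines from this diff:

import forge

model_name = "deepseek-ai/deepseek-math-7b-instruct"  # assumed checkpoint, not confirmed by the diff
model, tokenizer, input_ids = download_model_and_tokenizer(model_name)

# Compile the CPU-side model for the TT device, tracing with the prepared prompt.
compiled_model = forge.compile(model, sample_inputs=[input_ids])

# Greedy-decode up to 512 new tokens and inspect the answer.
answer = generation(max_new_tokens=512, compiled_model=compiled_model, input_ids=input_ids, tokenizer=tokenizer)
print(answer)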
vkovinicTT (Member) commented:

Can we just move this to the utils folder?
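
For illustration, one way the move could look; the module path below is hypothetical, not the repo's actual layout:

# Hypothetical shared module, e.g. test/models/utils.py, hosting the helpers;
# the DeepSeek test would then import them:
from test.models.utils import download_model_and_tokenizer, generation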

        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1)

        if next_token_id == tokenizer.eos_token_id:

Contributor commented:

Comparing a tensor to an integer can work when the tensor has a single element, but it’s clearer and safer to extract the scalar value :))
Something like this should work:

        if next_token_id.item() == tokenizer.eos_token_id:
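
Applied inside the loop, the suggested fix would read as below (illustrative only, not the committed change):

        # argmax over the vocab axis yields a one-element tensor for batch size 1;
        # .item() extracts the Python int so the comparison is scalar-to-scalar.
        next_token_id = torch.argmax(next_token_logits, dim=-1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break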

    return generated_text


def download_model_and_tokenizer(model_name):

Contributor commented:

One suggestion: would it be possible to add an option to set use_cache to True? We might consider an approach similar to what’s shown here. Also, as @vkovinicTT mentioned, this function might fit better in the utils folder :))
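
A sketch of that option, assuming it is threaded through as a keyword argument (signature and default are illustrative, not the final API):

def download_model_and_tokenizer(model_name, use_cache=False):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cpu")
    model.generation_config = GenerationConfig.from_pretrained(model_name)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
    # Let callers opt into KV caching instead of hard-coding it off.
    model.generation_config.use_cache = use_cache
    ...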

Labels: None yet
Projects: None yet
4 participants