Add test for deepseek_vl #1136

Draft · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
forge/test/models/pytorch/multimodal/deepseek/image/training_pipelines.jpg filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions env/core_requirements.txt
@@ -51,3 +51,5 @@ pytorch_forecasting==1.0.0
 patool
 openpyxl==3.1.5
 GitPython==3.1.44
+dotmap==1.3.30
+einops==0.8.1
4 changes: 2 additions & 2 deletions env/linux_requirements.txt
@@ -17,11 +17,11 @@ sacrebleu==2.1.0
 sacremoses==0.0.53
 seaborn
 scikit-image==0.20.0 # For DenseNet 121 HF XRay model
-segmentation_models_pytorch==0.3.2
+segmentation_models_pytorch==0.4.0
 sentencepiece==0.2.0
 subword-nmt==0.3.8
 tensorflow-hub==0.12.0
-timm==0.6.12
+timm==0.9.16
 yolov5==7.0.9
 # The CPU versions of torch and torchvision are used due to their size being
 # several GB smaller which made a large impact on the performance of CI
Binary file added via Git LFS: forge/test/models/pytorch/multimodal/deepseek/image/training_pipelines.jpg (not displayed)
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest

import forge

from test.models.pytorch.multimodal.deepseek.utils.load_model import (
generate_model_deepseek_vl_pytorch,
generation,
)
from test.models.utils import Framework, Source, Task, build_module_name


@pytest.mark.parametrize("variant", ["deepseek-ai/deepseek-vl-1.3b-base"])
def test_deepseek_vl_no_cache_cpu_pytorch(record_forge_property, variant):
    # Load the model, tokenizer, and precomputed input embeddings
    framework_model, vl_gpt, tokenizer, inputs_embeds = generate_model_deepseek_vl_pytorch(variant)

    # Generate on CPU without a KV cache as a reference run
    answer = generation(
        max_new_tokens=512, model=framework_model, inputs_embeds=inputs_embeds, tokenizer=tokenizer, vl_gpt=vl_gpt
    )

    print(answer)


@pytest.mark.parametrize("variant", ["deepseek-ai/deepseek-vl-1.3b-base"])
def test_deepseek_vl_pytorch(record_forge_property, variant):
    # Build Module Name
    module_name = build_module_name(
        framework=Framework.PYTORCH, model="deepseek", variant=variant, task=Task.QA, source=Source.HUGGINGFACE
    )

    # Record Forge Property
    record_forge_property("model_name", module_name)

    # Compile the model with Forge, then run single-token generation through the compiled model
    framework_model, vl_gpt, tokenizer, inputs_embeds = generate_model_deepseek_vl_pytorch(variant)
    compiled_model = forge.compile(framework_model, sample_inputs=[inputs_embeds], module_name=module_name)
    answer = generation(
        max_new_tokens=1, model=compiled_model, inputs_embeds=inputs_embeds, tokenizer=tokenizer, vl_gpt=vl_gpt
    )

    print(answer)
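
For context, the helpers generate_model_deepseek_vl_pytorch and generation come from the PR's load_model utility, whose body is not shown in this diff. Below is a minimal sketch of what a no-cache greedy decoding loop over inputs_embeds typically looks like; the helper name embed_fn and the loop itself are illustrative assumptions, not the PR's implementation.

import torch

def greedy_generate_no_cache(model, inputs_embeds, tokenizer, embed_fn, max_new_tokens=16):
    # Illustrative sketch only: re-runs the full sequence each step instead of using a KV cache.
    # embed_fn (hypothetical) maps a token id to its input embedding, e.g. the
    # language model's input embedding layer.
    generated = []
    for _ in range(max_new_tokens):
        logits = model(inputs_embeds=inputs_embeds).logits  # assumes an HF-style output object
        next_id = torch.argmax(logits[:, -1, :], dim=-1)  # greedy pick at the last position
        if next_id.item() == tokenizer.eos_token_id:  # assumes batch size 1
            break
        generated.append(next_id.item())
        # Append the new token's embedding and feed the whole sequence back in
        inputs_embeds = torch.cat([inputs_embeds, embed_fn(next_id).unsqueeze(1)], dim=1)
    return tokenizer.decode(generated)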
318 changes: 318 additions & 0 deletions forge/test/models/pytorch/multimodal/deepseek/utils/conversation.py
@@ -0,0 +1,318 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0


# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Dict, List


class SeparatorStyle(IntEnum):
"""Separator styles."""

ADD_COLON_SINGLE = auto()
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
ROBIN = auto()
DeepSeek = auto()
PLAIN = auto()
ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""

# The name of this template
name: str
# The template of the system prompt
system_template: str = "{system_message}"
# The system message
system_message: str = ""
# The names of two roles
    roles: List[str] = ("USER", "ASSISTANT")
# All messages. Each item is (role, message).
messages: List[List[str]] = ()
# The number of few shot examples
offset: int = 0
# The separator style and configurations
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
sep: str = "\n"
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: str = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None

def get_prompt(self) -> str:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)

if self.sep_style == SeparatorStyle.DeepSeek:
seps = [self.sep, self.sep2]
if system_prompt == "" or system_prompt is None:
ret = ""
else:
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
ret = system_prompt
else:
ret = "[INST] "
for i, (role, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
if type(message) is tuple: # multimodal message
message, _ = message
if i == 0:
ret += message + " "
else:
ret += tag + " " + message + seps[i % 2]
else:
ret += tag
return ret
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
                    ret += message + seps[i % 2]
else:
ret += ""
return ret
elif self.sep_style == SeparatorStyle.ALIGNMENT:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
if i % 2 == 0:
ret += "<image>\n" + seps[i % 2]
else:
ret += message + seps[i % 2]
else:
ret += ""
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

def get_prompt_for_current_round(self, content=None):
"""Get current round formatted question prompt during sft training"""
if self.sep_style == SeparatorStyle.PLAIN:
formatted_question = "<image>\n"
elif self.sep_style == SeparatorStyle.DeepSeek:
formatted_question = f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
else:
raise ValueError(f"Unsupported sep_style: {self.sep_style}")
return formatted_question

def set_system_message(self, system_message: str):
"""Set the system message."""
self.system_message = system_message

def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])

def reset_message(self):
"""Reset a new message."""
self.messages = []

def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret

def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
system_prompt = self.system_template.format(system_message=self.system_message)
ret = [{"role": "system", "content": system_prompt}]

for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
if msg is not None:
ret.append({"role": "assistant", "content": msg})
return ret

def copy(self):
return Conversation(
name=self.name,
system_template=self.system_template,
system_message=self.system_message,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
)

def dict(self):
return {
"template_name": self.name,
"system_message": self.system_message,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
}


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
assert template.name not in conv_templates, f"{template.name} has been registered."

conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
"""Get a conversation template."""
return conv_templates[name].copy()


# llava_llama2 template
register_conv_template(
Conversation(
name="llava_llama2",
system_message="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
)
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
Conversation(
name="llama-2",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
)
)


# deepseek template
register_conv_template(
Conversation(
name="deepseek",
system_template="{system_message}",
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
# "thinking step by step to be sure you get the right answer.",
system_message="",
roles=("User", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DeepSeek,
sep="\n\n",
sep2="<|end▁of▁sentence|>",
stop_token_ids=[100001],
stop_str=["User:", "<|end▁of▁sentence|>"],
)
)

register_conv_template(
Conversation(
name="plain",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="",
sep2="",
stop_token_ids=[2],
stop_str=["</s>"],
)
)


register_conv_template(
Conversation(
name="alignment",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.ALIGNMENT,
sep="",
sep2="",
stop_token_ids=[2],
stop_str=["</s>"],
)
)
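
As a usage illustration (not part of the PR), the registry above can be exercised as follows; the expected prompt string follows directly from the deepseek template's separators. The import path is an assumption and should be adjusted to wherever this module lives.

from conversation import get_conv_template  # adjust to the actual module path

conv = get_conv_template("deepseek")  # returns a copy, so the registry entry is untouched
conv.append_message(conv.roles[0], "Describe the image.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open for generation

prompt = conv.get_prompt()
# With an empty system message and sep="\n\n", this yields:
# "User: Describe the image.\n\nAssistant:"
print(repr(prompt))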