Add support for Molmo-D-7B Model #1542

Closed
wants to merge 17 commits
python/sglang/srt/conversation.py: 18 additions & 3 deletions
@@ -72,6 +72,7 @@ class Conversation:
    stop_str: Union[str, List[str]] = None
    image_data: Optional[List[str]] = None
    modalities: Optional[List[str]] = None
    image_token_str: Optional[str] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
@@ -334,6 +335,7 @@ def copy(self):
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            image_token_str=self.image_token_str,
        )

    def dict(self):
@@ -381,6 +383,7 @@ def generate_chat_conv(
        stop_str=conv.stop_str,
        image_data=[],
        modalities=[],
        image_token_str=conv.image_token_str,
    )

    if isinstance(request.messages, str):
@@ -412,16 +415,15 @@
                    num_image_url += 1
                    conv.modalities.append(content.modalities)
            if num_image_url > 1:
-                image_token = "<image>"
+                image_token = conv.image_token_str
            else:
-                image_token = "<image>\n"
+                image_token = conv.image_token_str + "\n"
            for content in message.content:
                if content.type == "text":
                    if num_image_url > 16:
                        real_content += "\n"  # for video
                    real_content += content.text
                elif content.type == "image_url":
-                    # NOTE: Only works for llava
                    real_content += image_token
                    conv.append_image(content.image_url.url)
            conv.append_message(conv.roles[0], real_content)
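
The hunk above generalizes the placeholder from the hard-coded llava token to the template's own image_token_str. A minimal standalone sketch of the resulting expansion logic (the helper name and dict-based message parts are illustrative, not the sglang API):

def expand_content(parts, image_token_str, num_image_url):
    # Multiple images: pack placeholders with no trailing newline.
    # Single image: separate the placeholder from the text with "\n".
    image_token = image_token_str if num_image_url > 1 else image_token_str + "\n"
    out = ""
    for part in parts:
        if part["type"] == "text":
            out += part["text"]
        elif part["type"] == "image_url":
            out += image_token
    return out

# expand_content([{"type": "image_url"}, {"type": "text", "text": "Describe it."}], "<image>", 1)
# returns '<image>\nDescribe it.'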
@@ -485,6 +487,7 @@ def generate_chat_conv(
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
        stop_str=["<|endoftext|>", "<|im_end|>"],
        image_token_str="<image>",
    )
)

@@ -509,6 +512,7 @@
        sep_style=SeparatorStyle.LLAMA3,
        sep="",
        stop_str=["<|end_of_text|>", "<|eot_id|>"],
        image_token_str="<image>",
    )
)

# Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
@@ -521,3 +525,14 @@ def generate_chat_conv(
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)

register_conv_template(
    Conversation(
        name="molmo",
        system_template="",
        roles=("User", "Assistant"),
        sep=" ",
        stop_str=["<|endoftext|>"],
        image_token_str="",
    )
)
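
Note that the molmo template sets image_token_str="", presumably because the text prompt carries no visible placeholder: Molmo's image tokens (<im_start>, <im_patch>, <im_col>, <im_end>) are injected at the token level by the image processor below rather than as prompt text. Reusing the hypothetical expand_content helper sketched above:

# chatml-style template: the placeholder appears in the prompt text
expand_content([{"type": "image_url"}, {"type": "text", "text": "Hi"}], "<image>", 1)
# returns '<image>\nHi'

# molmo template: only the separating newline is emitted; image tokens
# are spliced into input_ids later by MolmoImageProcessor
expand_content([{"type": "image_url"}, {"type": "text", "text": "Hi"}], "", 1)
# returns '\nHi'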
python/sglang/srt/managers/image_processor.py: 138 additions & 1 deletion
@@ -8,6 +8,7 @@
from typing import List, Optional, Union

import numpy as np
import torch
import transformers

from sglang.srt.hf_transformers_utils import get_processor
@@ -177,10 +178,146 @@ async def process_images_async(
        }


class MolmoImageProcessor(BaseImageProcessor):
    SPECIAL_TOKEN_TO_ID = {
        "<im_patch>": 152066,
        "<im_start>": 152064,
        "<im_end>": 152065,
        "<im_col>": 152067,
        "<|image|>": 152068,
    }

    def __init__(self, hf_config, server_args, _image_processor):
        self.hf_config = hf_config
        self._image_processor = _image_processor
        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(server_args,),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
        )
        self.image_patch_token_id = self.SPECIAL_TOKEN_TO_ID["<im_patch>"]
        self.image_start_token_id = self.SPECIAL_TOKEN_TO_ID["<im_start>"]
        self.image_end_token_id = self.SPECIAL_TOKEN_TO_ID["<im_end>"]
        self.image_col_token_id = self.SPECIAL_TOKEN_TO_ID["<im_col>"]
        self.image_prompt_token_id = self.SPECIAL_TOKEN_TO_ID["<|image|>"]

    @staticmethod
    def _process_image_task(
        image_data_list: List[Union[str, bytes]],
        input_ids: List[int],
        image_patch_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        image_col_token_id: int,
    ):
        global global_processor

        # Adapted from https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
        # Returns: input_ids, image_input_idx, images, image_masks
        images = []
        image_sizes = []
        image_hashes = []
        for image_data in image_data_list:
            image, image_size = load_image(image_data)
            image = image.convert("RGB")
            image_hashes.append(hash(image_data))
            images.append(np.array(image))
            image_sizes.append(image_size)
        hf_dict = global_processor.image_processor.multimodal_preprocess(
            images=images,
            image_idx=[-1] * len(images),
            tokens=np.asarray(input_ids).astype(np.int32),
            sequence_length=len(input_ids),
            image_patch_token_id=image_patch_token_id,
            image_col_token_id=image_col_token_id,
            image_start_token_id=image_start_token_id,
            image_end_token_id=image_end_token_id,
        )

        bos = (
            global_processor.tokenizer.bos_token_id
            or global_processor.tokenizer.eos_token_id
        )
        decoder_input_tokens = np.pad(
            hf_dict["input_ids"], [[1, 0]], constant_values=bos
        )
        hf_dict["input_ids"] = decoder_input_tokens
        if "image_input_idx" in hf_dict:
            # Shift the patch mapping up by one since we added a BOS token
            image_input_idx = hf_dict["image_input_idx"]
            hf_dict["image_input_idx"] = np.where(
                image_input_idx < 0, image_input_idx, image_input_idx + 1
            )

        for k, v in hf_dict.items():
            hf_dict[k] = torch.from_numpy(v)

        hf_dict["image_hashes"] = image_hashes
        hf_dict["pixel_values"] = hf_dict["images"]
        hf_dict["image_sizes"] = image_sizes

        del hf_dict["images"]

        return hf_dict

    async def _process_image(
        self, image_data_list: List[Union[bytes, str]], input_ids: List[int]
    ):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                MolmoImageProcessor._process_image_task,
                image_data_list,
                input_ids,
                self.image_patch_token_id,
                self.image_start_token_id,
                self.image_end_token_id,
                self.image_col_token_id,
            )
        else:
            return self._process_image_task(
                image_data_list,
                input_ids,
                self.image_patch_token_id,
                self.image_start_token_id,
                self.image_end_token_id,
                self.image_col_token_id,
            )

    async def process_images_async(self, image_data, request_obj, **kwargs):
        if not image_data:
            return None

        input_ids = request_obj.input_ids
        res = {}
        if isinstance(image_data, list) and len(image_data) > 0:
            # Multiple images
            if len(image_data) > 1:
                res = await self._process_image(image_data, input_ids)
            else:
                res = await self._process_image(image_data[0:1], input_ids)
        elif isinstance(image_data, str):
            # A single image
            res = await self._process_image([image_data], input_ids)
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        res["modalities"] = request_obj.modalities
        return res


def get_image_processor(
    hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
-    return LlavaImageProcessor(hf_config, server_args, _image_processor)
+    if "MolmoForCausalLM" in hf_config.architectures:
+        return MolmoImageProcessor(hf_config, server_args, _image_processor)
+    return LlavaImageProcessor(hf_config, server_args, _image_processor)


def get_dummy_image_processor():
    ...
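
The subtlest step in _process_image_task is the BOS prepend: multimodal_preprocess computes image_input_idx as positions into the unpadded token sequence, so once np.pad inserts a BOS token, every valid index must shift up by one while the -1 padding slots stay put. A self-contained numpy sketch of just that adjustment (token ids and shapes are illustrative):

import numpy as np

# <im_start> <im_patch> <im_patch> <im_end>, using ids from SPECIAL_TOKEN_TO_ID
input_ids = np.array([152064, 152066, 152066, 152065], dtype=np.int32)
image_input_idx = np.array([[1, 2, -1]])  # patch positions; -1 marks unused slots

bos = 151643  # hypothetical BOS id
input_ids = np.pad(input_ids, [[1, 0]], constant_values=bos)
# Shift only the non-negative indices to account for the inserted BOS
image_input_idx = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

print(input_ids.tolist())        # [151643, 152064, 152066, 152066, 152065]
print(image_input_idx.tolist())  # [[2, 3, -1]]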
python/sglang/srt/managers/schedule_batch.py: 13 additions & 0 deletions
@@ -128,6 +128,11 @@ class ImageInputs:
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # For Molmo
    input_ids: Optional[List[torch.Tensor]] = None
    image_input_idx: Optional[List[torch.Tensor]] = None
    image_masks: Optional[List[torch.Tensor]] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use the image hash as fake token_ids, which are then used for prefix matching
@@ -145,6 +150,14 @@ def from_dict(obj, vocab_size):
        ret.image_sizes = obj["image_sizes"]
        # Only when pixel_values is not None do we have modalities
        ret.modalities = obj["modalities"] or ["image"]

        # For Molmo
        if "input_ids" in obj:
            ret.input_ids = obj["input_ids"]
        if "image_input_idx" in obj:
            ret.image_input_idx = obj["image_input_idx"]
        if "image_masks" in obj:
            ret.image_masks = obj["image_masks"]
        return ret
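
These new fields ride along with the hash-based prefix-matching trick noted at the top of from_dict: each image hash returned by the processor is folded into the vocabulary range and stands in as a fake token id, so the radix cache can treat a repeated image like a repeated token prefix. A rough sketch of that idea (variable names are illustrative; the actual mapping lives in the elided body of from_dict):

# Fold 64-bit image hashes into [0, vocab_size) so they can act as
# stand-in token ids during prefix matching.
image_hashes = [hash(b"image-bytes-a"), hash(b"image-bytes-b")]
vocab_size = 152069
fake_token_ids = [h % vocab_size for h in image_hashes]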

