Organize image inputs (#1531)
hnyls2002 authored Sep 29, 2024
1 parent e165a9f commit fd9ad81
Showing 8 changed files with 121 additions and 132 deletions.
python/sglang/srt/managers/io_struct.py (10 changes: 2 additions & 8 deletions)

@@ -172,12 +172,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +184,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None
 
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model

python/sglang/srt/managers/schedule_batch.py (51 changes: 37 additions & 14 deletions)

@@ -102,6 +102,39 @@ def to_json(self):
         }
 
 
+@dataclass
+class ImageInputs:
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"]
+        return ret
+
+
 class Req:
     """Store all inforamtion of a request."""
 
@@ -147,11 +180,7 @@ def __init__(
         self.completion_tokens_wo_jump_forward = 0
 
         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
        self.prefix_indices = []
@@ -654,15 +683,9 @@ def check_for_jump_forward(self, model_runner):
                 self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )
 
                 jump_forward_reqs.append(req)

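The new ImageInputs.from_dict helper consolidates logic that previously lived inline in tp_worker.py: it folds all per-image hashes into a single hash and slices that hash into four pseudo-token ids used for radix-cache prefix matching. A standalone sketch of just that slicing step, with made-up hash and vocab-size values:

# Illustration only, not part of the diff; the inputs below are invented.
image_hashes = [4051049698, 2873906571]  # hypothetical per-image hashes
vocab_size = 32000                       # hypothetical model vocab size

image_hash = hash(tuple(image_hashes))
pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
# The four ids act as fake tokens standing in for the image region, so two
# requests carrying identical images yield identical padded token ids and
# can share a cached prefix.
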
python/sglang/srt/managers/tokenizer_manager.py (37 changes: 16 additions & 21 deletions)

@@ -194,10 +194,9 @@ async def _handle_single_request(
             )
 
             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                    obj.image_data if not_use_index else obj.image_data[index]
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data if not_use_index else obj.image_data[index]
                 )
-                modalities = obj.modalities
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
                 )
@@ -248,10 +247,7 @@ async def _handle_single_request(
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            modalities = obj.modalities
+            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ async def _handle_single_request(
                 rid,
                 input_text,
                 input_ids,
-                pixel_values,
-                image_hashes,
-                image_sizes,
+                image_inputs,
                 sampling_params,
                 return_logprob,
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                modalities,
                 (
                     obj.lora_path[index]
                     if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ async def _handle_batch_request(
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
 
                 if self.is_generation:
-                    pixel_values, image_hashes, image_sizes = (
-                        await self._get_pixel_values(obj.image_data[index])
+                    image_inputs = await self._get_image_inputs(
+                        obj, obj.image_data[index]
                     )
-                    modalities = obj.modalities
 
                     tokenized_obj = TokenizedGenerateReqInput(
                         rid,
                         input_text,
                         input_ids,
-                        pixel_values,
-                        image_hashes,
-                        image_sizes,
+                        image_inputs,
                         sampling_params,
                         obj.return_logprob[index],
                         obj.logprob_start_len[index],
                         obj.top_logprobs_num[index],
                         obj.stream,
-                        modalities,
                         (
                             obj.lora_path[index]
                             if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
             )
         return top_logprobs
 
-    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
         if not image_data:
-            return None, None, None
+            return None
 
+        # TODO: move this into a processor for each vision architecture
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
         else:
             raise ValueError(f"Invalid image data: {image_data}")
 
-        return pixel_values, image_hashes, image_sizes
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": obj.modalities,
+        }
 
     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str

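After this refactor the tokenizer process hands one plain dict to the scheduler instead of three parallel values plus a separate modalities field. A hedged sketch of that handoff with placeholder values; the keys match what _get_image_inputs now returns, while the tensor shape and hash are invented:

import torch

# Hypothetical payload shaped like the dict _get_image_inputs builds.
image_inputs = {
    "pixel_values": torch.zeros(1, 3, 336, 336),  # made-up processor output
    "image_hashes": [4051049698],                 # one hash per input image
    "image_sizes": [[336, 336]],                  # per-image (height, width)
    "modalities": ["image"],                      # copied from the request obj
}

# On the receiving side, tp_worker.py turns it back into a typed object:
# req.image_inputs = ImageInputs.from_dict(image_inputs, vocab_size)

Keeping the wire format a plain dict also means io_struct.py does not need a torch import; the typed ImageInputs exists only where it is consumed.
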
python/sglang/srt/managers/tp_worker.py (32 changes: 10 additions & 22 deletions)

@@ -49,6 +49,7 @@
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
+    ImageInputs,
     Req,
     ScheduleBatch,
 )
@@ -340,29 +341,16 @@ def handle_generate_request(
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            # Use image hash as fake token_ids, which is then used
-            # for prefix matching
-            image_hash = hash(tuple(recv_req.image_hashes))
-            req.pad_value = [
-                (image_hash) % self.model_config.vocab_size,
-                (image_hash >> 16) % self.model_config.vocab_size,
-                (image_hash >> 32) % self.model_config.vocab_size,
-                (image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_sizes = recv_req.image_sizes
-            (
-                req.origin_input_ids,
-                req.image_offsets,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values,
-                req.image_sizes,
+
+        # Image inputs
+        if recv_req.image_inputs is not None:
+            req.image_inputs = ImageInputs.from_dict(
+                recv_req.image_inputs, self.model_config.vocab_size
+            )
+            req.origin_input_ids = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded, req.image_inputs
             )
-            # Only when pixel values is not None we have modalities
-            req.modalities = recv_req.modalites
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream

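Note the contract change on pad_input_ids: the old call passed pad_value, pixel_values, and image_sizes separately and got (input_ids, image_offsets) back; the new call passes the whole ImageInputs and returns only input_ids, which implies each model now records image_offsets on the ImageInputs object itself. A toy sketch of that shape, with a hypothetical placeholder token id; the real per-model implementations live under python/sglang/srt/models/:

from typing import List

IMAGE_PLACEHOLDER_ID = -1  # hypothetical marker meaning "an image goes here"

def pad_input_ids(input_ids: List[int], image_inputs) -> List[int]:
    # Record where the image sits on the ImageInputs object itself, then
    # splice in the pad values so the ids participate in prefix matching.
    offset = input_ids.index(IMAGE_PLACEHOLDER_ID)
    image_inputs.image_offsets = [offset]
    return input_ids[:offset] + image_inputs.pad_values + input_ids[offset + 1 :]
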
python/sglang/srt/model_executor/forward_batch_info.py (13 changes: 3 additions & 10 deletions)

@@ -25,7 +25,7 @@
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
 
@@ -84,17 +84,10 @@ class InputMetadata:
     extend_logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
+    image_inputs: List[ImageInputs] = None
 
     def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
+        self.image_inputs = [r.image_inputs for r in batch.reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         if self.forward_mode.is_decode():

python/sglang/srt/model_executor/model_runner.py (15 changes: 1 addition & 14 deletions)

@@ -498,23 +498,10 @@ def forward_extend(self, batch: ScheduleBatch):
             get_embedding=True,
         )
 
-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
-
     def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None
 
-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
+        if batch.forward_mode.is_decode():
             return self.forward_decode(batch)
         elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)

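Dropping forward_extend_multi_modal works because image data now rides inside InputMetadata.image_inputs (populated by init_multimuldal_info above), so the plain forward_extend path already delivers everything a vision-language model needs. A hedged sketch of the consuming side inside a model's forward; embed_images, embed_tokens, and compute_logits are invented names for illustration:

# Sketch only; the real models under python/sglang/srt/models/ differ in detail.
def forward(self, input_ids, positions, input_metadata):
    image_inputs = input_metadata.image_inputs
    if image_inputs is not None and any(im is not None for im in image_inputs):
        # Each non-None entry is a per-request ImageInputs carrying
        # pixel_values, image_sizes, image_offsets, and modalities.
        hidden_states = self.embed_images(input_ids, image_inputs)
    else:
        hidden_states = self.embed_tokens(input_ids)
    return self.compute_logits(hidden_states, positions, input_metadata)
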