Skip to content

Commit

Permalink
Move multimodal infos (pixel_values, image_sizes, image_offsets) from ScheduleBatch to InputMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 committed Aug 5, 2024
1 parent 3e67727 commit c164fb3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
15 changes: 1 addition & 14 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,6 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: List[int] = None

# For multimodal
pixel_values: List[torch.Tensor] = None
image_sizes: List[List[int]] = None
image_offsets: List[int] = None

# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
Expand Down Expand Up @@ -447,19 +442,12 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
] = out_cache_loc[pt : pt + extend_lens_cpu[i]]
pt += extend_lens_cpu[i]

# Image auxiliary
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(r.image_offset - p_len) if r.image_offset is not None else 0
for r, p_len in zip(reqs, prefix_lens_cpu)
]

self.extend_num_tokens = extend_num_tokens
self.total_num_tokens = total_num_tokens
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = prefix_lens_cpu
self.extend_lens_cpu = extend_lens_cpu
self.out_cache_loc = out_cache_loc

with torch.device("cuda"):
# Batched tensors
Expand All @@ -469,7 +457,6 @@ def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
self.out_cache_loc = out_cache_loc

self.batch_sampling_params(vocab_size, int_token_logit_bias)

Expand Down
16 changes: 15 additions & 1 deletion python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class InputMetadata:
extend_start_loc: torch.Tensor = None
extend_no_prefix: bool = None

# For multimodal
pixel_values: List[torch.Tensor] = None
image_sizes: List[List[int]] = None
image_offsets: List[int] = None

# Output options
return_logprob: bool = False
top_logprobs_nums: List[int] = None
Expand All @@ -62,6 +67,15 @@ class InputMetadata:
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
flashinfer_use_ragged: bool = False

def init_multimodal_infos(self, batch: ScheduleBatch):
    """Populate the multimodal fields of this object from ``batch``.

    Reads ``batch.reqs`` and ``batch.prefix_lens_cpu`` and sets three
    per-request lists on ``self``, each in the same order as ``batch.reqs``:

    * ``pixel_values``  -- ``r.pixel_values`` of every request.
    * ``image_sizes``   -- ``r.image_size`` of every request.
    * ``image_offsets`` -- ``r.image_offset`` minus the request's entry in
      ``batch.prefix_lens_cpu``, or ``0`` when ``r.image_offset`` is ``None``.

    Args:
        batch: The batch whose requests supply the image inputs.
    """
    reqs = batch.reqs
    self.pixel_values = [r.pixel_values for r in reqs]
    self.image_sizes = [r.image_size for r in reqs]
    # Rebase each offset by the request's prefix length; requests with no
    # image offset get 0.
    # NOTE(review): presumably this makes the offset relative to the tokens
    # actually being extended -- confirm against the model code that uses it.
    self.image_offsets = [
        (r.image_offset - p_len) if r.image_offset is not None else 0
        for r, p_len in zip(reqs, batch.prefix_lens_cpu)
    ]

def compute_positions(self, batch: ScheduleBatch):
bs = self.batch_size
if self.forward_mode == ForwardMode.DECODE:
Expand Down Expand Up @@ -96,7 +110,6 @@ def compute_extend_infos(self, batch: ScheduleBatch):

@classmethod
def from_batch(cls, model_runner, batch: ScheduleBatch, forward_mode: ForwardMode):

ret = cls(
forward_mode=forward_mode,
batch_size=batch.batch_size(),
Expand All @@ -113,6 +126,7 @@ def from_batch(cls, model_runner, batch: ScheduleBatch, forward_mode: ForwardMod
ret.compute_positions(batch)

if forward_mode != ForwardMode.DECODE:
ret.init_multimodal_infos(batch)
ret.compute_extend_infos(batch)

prefix_lens = (
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,9 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch):
batch.input_ids,
input_metadata.positions,
input_metadata,
batch.pixel_values,
batch.image_sizes,
batch.image_offsets,
input_metadata.pixel_values,
input_metadata.image_sizes,
input_metadata.image_offsets,
)

def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
Expand Down

0 comments on commit c164fb3

Please sign in to comment.