Skip to content

Commit

Permalink
[Refactor] Simplify io_struct and tokenizer_manager (#1549)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Oct 1, 2024
1 parent 100f5b8 commit f202ed9
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 167 deletions.
108 changes: 49 additions & 59 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class GenerateReqInput:
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs.
Expand All @@ -55,28 +55,47 @@ class GenerateReqInput:
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

# Whether it is a single request or a batch request
is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")

if (
isinstance(self.sampling_params, dict)
and self.sampling_params.get("n", 1) != 1
):
is_single = False
self.is_single = False
if self.text is not None:
if isinstance(self.text, str):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.text)
else:
if self.text is not None:
is_single = isinstance(self.text, str)
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
is_single = isinstance(self.input_ids[0], int)
self.is_single = is_single
self.batch_size = len(self.input_ids)

if self.sampling_params is None:
self.parallel_sample_num = 1
if isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
for sp in self.sampling_params:
# TODO: handle the case where parallel_sample_num differs across samples
assert self.parallel_sample_num == sp.get(
"n", 1
), "The parallel_sample_num should be the same for all samples in sample params."

if self.parallel_sample_num > 1:
if self.is_single:
self.is_single = False
if self.text is not None:
self.text = [self.text]
if self.input_ids is not None:
self.input_ids = [self.input_ids]

if is_single:
if self.is_single:
if self.sampling_params is None:
self.sampling_params = {}
if self.rid is None:
Expand All @@ -88,79 +107,54 @@ def post_init(self):
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
parallel_sample_num_list = []
if isinstance(self.sampling_params, dict):
parallel_sample_num = self.sampling_params.get("n", 1)
elif isinstance(self.sampling_params, list):
for sp in self.sampling_params:
parallel_sample_num = sp.get("n", 1)
parallel_sample_num_list.append(parallel_sample_num)
parallel_sample_num = max(parallel_sample_num_list)
all_equal = all(
element == parallel_sample_num
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
# TODO cope with the case that the parallel_sample_num is different for different samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
if self.parallel_sample_num == 1:
num = self.batch_size
else:
parallel_sample_num = 1
self.parallel_sample_num = parallel_sample_num

if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
num = parallel_sample_num + 1
if isinstance(self.text, list):
# support batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
elif isinstance(self.input_ids, list) and isinstance(
self.input_ids[0], list
):
self.batch_size = len(self.input_ids)
num = num * len(self.input_ids)
else:
self.batch_size = 1
else:
# support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num
# FIXME support cascade inference
# first bs samples are used for caching the prefix for parallel sampling
num = self.batch_size + self.parallel_sample_num * self.batch_size

if self.image_data is None:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
elif isinstance(self.image_data, list):
# multi-image with n > 1
# FIXME incorrect order for duplication
self.image_data = self.image_data * num

if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
self.sampling_params = [self.sampling_params] * num
else:
assert self.parallel_sample_num == 1

if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
assert isinstance(self.rid, list), "The rid should be a list."
assert self.parallel_sample_num == 1

if self.return_logprob is None:
self.return_logprob = [False] * num
elif not isinstance(self.return_logprob, list):
self.return_logprob = [self.return_logprob] * num
else:
assert self.parallel_sample_num == 1

if self.logprob_start_len is None:
self.logprob_start_len = [-1] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
else:
assert self.parallel_sample_num == 1

if self.top_logprobs_num is None:
self.top_logprobs_num = [0] * num
elif not isinstance(self.top_logprobs_num, list):
self.top_logprobs_num = [self.top_logprobs_num] * num
else:
assert self.parallel_sample_num == 1


@dataclass
Expand Down Expand Up @@ -199,8 +193,6 @@ class EmbeddingReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
Expand Down Expand Up @@ -255,8 +247,6 @@ class RewardReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

is_single: bool = True

def post_init(self):
self.is_single = isinstance(self.conv[0], dict)

Expand Down
Loading

0 comments on commit f202ed9

Please sign in to comment.