[Refactor] Simplify io_struct and tokenizer_manager #1549

Merged · 4 commits · Oct 1, 2024
108 changes: 49 additions & 59 deletions python/sglang/srt/managers/io_struct.py
@@ -36,7 +36,7 @@ class GenerateReqInput:
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs.
@@ -55,28 +55,47 @@ class GenerateReqInput:
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

# Whether it is a single request or a batch request
is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
raise ValueError("Either text or input_ids should be provided.")

if (
isinstance(self.sampling_params, dict)
and self.sampling_params.get("n", 1) != 1
):
is_single = False
self.is_single = False
if self.text is not None:
if isinstance(self.text, str):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.text)
else:
if self.text is not None:
is_single = isinstance(self.text, str)
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
is_single = isinstance(self.input_ids[0], int)
self.is_single = is_single
self.batch_size = len(self.input_ids)

if self.sampling_params is None:
self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):
self.parallel_sample_num = self.sampling_params[0].get("n", 1)
            for sp in self.sampling_params:
                # TODO: handle the case where parallel_sample_num differs across samples
                assert self.parallel_sample_num == sp.get(
                    "n", 1
                ), "The parallel_sample_num should be the same for all samples in sampling_params."

if self.parallel_sample_num > 1:
if self.is_single:
self.is_single = False
if self.text is not None:
self.text = [self.text]
if self.input_ids is not None:
self.input_ids = [self.input_ids]

if is_single:
if self.is_single:
if self.sampling_params is None:
self.sampling_params = {}
if self.rid is None:
Expand All @@ -88,79 +107,54 @@ def post_init(self):
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
parallel_sample_num_list = []
if isinstance(self.sampling_params, dict):
parallel_sample_num = self.sampling_params.get("n", 1)
elif isinstance(self.sampling_params, list):
for sp in self.sampling_params:
parallel_sample_num = sp.get("n", 1)
parallel_sample_num_list.append(parallel_sample_num)
parallel_sample_num = max(parallel_sample_num_list)
all_equal = all(
element == parallel_sample_num
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
                    # TODO: handle the case where parallel_sample_num differs across samples
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sampling_params."
)
if self.parallel_sample_num == 1:
num = self.batch_size
else:
parallel_sample_num = 1
self.parallel_sample_num = parallel_sample_num

if parallel_sample_num != 1:
                # for parallel sampling, the +1 represents the original prefill stage
num = parallel_sample_num + 1
if isinstance(self.text, list):
                    # support batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
elif isinstance(self.input_ids, list) and isinstance(
self.input_ids[0], list
):
self.batch_size = len(self.input_ids)
num = num * len(self.input_ids)
else:
self.batch_size = 1
else:
# support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num
                # FIXME: support cascade inference
                # the first batch_size samples are used to cache the prefix for parallel sampling
num = self.batch_size + self.parallel_sample_num * self.batch_size
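                # e.g. batch_size = 2 with n = 3 gives num = 2 + 3 * 2 = 8:
                # two prefix-caching requests plus six parallel generations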

if self.image_data is None:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
elif isinstance(self.image_data, list):
# multi-image with n > 1
                # FIXME: the duplication order is incorrect
self.image_data = self.image_data * num
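                # note: list multiplication repeats the whole list, e.g.
                # [a, b] * 2 == [a, b, a, b], not each element in place,
                # hence the ordering FIXME above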

if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
self.sampling_params = [self.sampling_params] * num
else:
assert self.parallel_sample_num == 1

if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
else:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
assert isinstance(self.rid, list), "The rid should be a list."
assert self.parallel_sample_num == 1

if self.return_logprob is None:
self.return_logprob = [False] * num
elif not isinstance(self.return_logprob, list):
self.return_logprob = [self.return_logprob] * num
else:
assert self.parallel_sample_num == 1

if self.logprob_start_len is None:
self.logprob_start_len = [-1] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
else:
assert self.parallel_sample_num == 1

if self.top_logprobs_num is None:
self.top_logprobs_num = [0] * num
elif not isinstance(self.top_logprobs_num, list):
self.top_logprobs_num = [self.top_logprobs_num] * num
else:
assert self.parallel_sample_num == 1


@dataclass
@@ -199,8 +193,6 @@ class EmbeddingReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
@@ -255,8 +247,6 @@ class RewardReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None

is_single: bool = True

def post_init(self):
self.is_single = isinstance(self.conv[0], dict)
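        # a dict as the first element presumably means a single conversation
        # (a list of role/content messages) rather than a batch of conversations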

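For reference, a minimal, hypothetical usage sketch — not part of the PR — showing how the refactored GenerateReqInput.post_init normalizes a batch request with parallel sampling. The import path follows the file path in this diff, and post_init is called explicitly here, which in the server is presumably done by the tokenizer manager:

from sglang.srt.managers.io_struct import GenerateReqInput

# A batch of two prompts, each requesting n = 3 parallel samples.
req = GenerateReqInput(
    text=["What is the capital of France?", "Write a haiku."],
    sampling_params={"n": 3},
)
req.post_init()

print(req.is_single)            # False: a list of prompts is a batch request
print(req.batch_size)           # 2
print(req.parallel_sample_num)  # 3
# num = batch_size + parallel_sample_num * batch_size = 2 + 3 * 2 = 8,
# so per-request fields (rid, sampling_params, ...) are expanded to 8 entries.
print(len(req.rid))             # 8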