
Input_embeds support #2052

Merged
merged 15 commits on Nov 26, 2024
10 changes: 6 additions & 4 deletions docs/references/sampling_params.md
@@ -11,21 +11,23 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can specify either text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob.
# If return logprobs, the start location in the prompt for returning logprobs.
# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
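For context on how the new field is meant to be consumed, here is a minimal client-side sketch. The model name and server address are placeholders, and it assumes the server was launched with the radix cache disabled (the tokenizer manager change below enforces this); treat it as an illustration of the payload shape rather than the project's reference client.

```python
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and endpoint; substitute whatever your server is running.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
SERVER = "http://127.0.0.1:30000"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL)

# Precompute embeddings offline, then send them instead of text / input_ids.
input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
input_embeds = model.get_input_embeddings()(input_ids).squeeze(0).tolist()  # [seq_len][hidden]

response = requests.post(
    f"{SERVER}/generate",
    json={
        "input_embeds": input_embeds,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
    timeout=60,
)
print(response.json()["text"])
```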
30 changes: 25 additions & 5 deletions python/sglang/srt/managers/io_struct.py
@@ -29,8 +29,10 @@
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can specify either text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
@@ -59,10 +61,16 @@ class GenerateReqInput:
session_rid: Optional[Union[List[str], str]] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
if (
self.text is None and self.input_ids is None and self.input_embeds is None
) or (
self.text is not None
and self.input_ids is not None
and self.input_embeds is not None
):
raise ValueError("Either text or input_ids should be provided.")
raise ValueError(
"Either text, input_ids or input_embeds should be provided."
)

# Derive the batch size
if self.text is not None:
@@ -72,13 +80,21 @@ def normalize_batch_and_arguments(self):
else:
self.is_single = False
self.batch_size = len(self.text)
else:
self.input_embeds = None
elif self.input_ids is not None:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
self.is_single = False
self.batch_size = len(self.input_ids)
self.input_embeds = None
else:
if isinstance(self.input_embeds[0][0], float):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.input_embeds)

# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
@@ -201,6 +217,8 @@ class TokenizedGenerateReqInput:

# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

# Session id info for continual prompting
session_id: Optional[int] = None
@@ -217,6 +235,8 @@ class EmbeddingReqInput:
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Dummy input embeds for compatibility
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
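A small standalone sketch of the shape convention that `normalize_batch_and_arguments` relies on: a single request sends a 2-D list of shape `[seq_len][hidden]`, a batch sends a 3-D list of shape `[batch][seq_len][hidden]`, and the type of the first nested element decides which case applies. This mirrors the dispatch above and is illustrative only.

```python
from typing import List, Union

Embeds = Union[List[List[float]], List[List[List[float]]]]

def derive_batch_size(input_embeds: Embeds) -> int:
    # A 2-D list (first nested element is a float) is one request;
    # a 3-D list is a batch of requests.
    if isinstance(input_embeds[0][0], float):
        return 1
    return len(input_embeds)

single = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]   # [seq_len=2][hidden=3]
batch = [single, [[0.7, 0.8, 0.9]]]           # [batch=2][seq_len][hidden]
assert derive_batch_size(single) == 1
assert derive_batch_size(batch) == 2
```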
21 changes: 21 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -178,6 +178,7 @@ def __init__(
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
):
# Input and output info
@@ -191,6 +192,7 @@

self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds

# Memory pool info
self.req_pool_idx = None
@@ -448,6 +450,7 @@ class ScheduleBatch:

# Batched arguments to model runner
input_ids: torch.Tensor = None
input_embeds: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
@@ -631,6 +634,9 @@ def prepare_for_extend(self):
req_pool_indices = self.alloc_req_slots(bs)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)

input_embeds = []

pt = 0
for i, req in enumerate(reqs):
already_computed = (
req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -649,6 +655,11 @@
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)

# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting

# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
@@ -671,6 +682,12 @@
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.input_embeds = (
torch.tensor(input_embeds).to(self.device, non_blocking=True)
if input_embeds
else None
)

self.out_cache_loc = out_cache_loc

self.seq_lens_sum = sum(seq_lens)
@@ -1053,6 +1070,7 @@ def get_model_worker_batch(self):
encoder_out_cache_loc=self.encoder_out_cache_loc,
lora_paths=[req.lora_path for req in self.reqs],
sampling_info=self.sampling_info,
input_embeds=self.input_embeds,
)

def copy(self):
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo

# The input embeds
input_embeds: Optional[torch.tensor] = None


@triton.jit
def write_req_to_token_pool_triton(
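To make the `prepare_for_extend` change above concrete, here is a standalone sketch of how per-request embeddings are flattened into a single batch tensor; dtypes, devices, and the surrounding bookkeeping in sglang are omitted.

```python
import torch

# Two requests with different prompt lengths; each entry is [seq_len][hidden].
req_embeds = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # request 0: 3 tokens
    [[0.7, 0.8]],                          # request 1: 1 token
]

# As in prepare_for_extend, list.extend() flattens the per-request lists so the
# result is [total_extend_tokens][hidden] rather than a nested structure.
input_embeds = []
for embeds in req_embeds:
    input_embeds.extend(embeds)

batch_tensor = torch.tensor(input_embeds)
print(batch_tensor.shape)  # torch.Size([4, 2])
```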
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -525,12 +525,20 @@ def handle_generate_request(
recv_req: TokenizedGenerateReqInput,
):
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
# Check if input_embeds is present and create dummy input_ids
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds)
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids

req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
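The dummy-token trick in `handle_generate_request` only needs the prompt length to be correct, because downstream scheduling (sequence lengths, KV-cache slot allocation) is keyed off `input_ids` while the actual embeddings travel alongside the request. A tiny illustration with a made-up hidden size:

```python
# input_embeds for a single request: [seq_len][hidden]; 4096 is a made-up hidden size.
input_embeds = [[0.0] * 4096 for _ in range(7)]

# Placeholder ids with the same sequence length; their values are never
# embedded, since the real embeddings are supplied directly.
fake_input_ids = [1] * len(input_embeds)
assert len(fake_input_ids) == 7
```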
17 changes: 14 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -201,8 +201,18 @@ async def _tokenize_one_request(
):
"""Tokenize one request."""
# Tokenize
input_embeds = None
input_text = obj.text
if obj.input_ids is None:
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cach` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
input_ids = obj.input_ids
elif obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids
@@ -219,7 +229,7 @@
session_id = obj.session_id
session_rid = obj.session_rid

if len(input_ids) >= self.context_len:
if obj.input_ids is not None and len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
@@ -242,7 +252,8 @@
logprob_start_len,
top_logprobs_num,
obj.stream,
obj.lora_path,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_id=session_id,
session_rid=session_rid,
)
4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -130,6 +130,9 @@ class ForwardBatch:
# For LoRA
lora_paths: Optional[List[str]] = None

# For input embeddings
input_embeds: Optional[torch.tensor] = None

# Sampling info
sampling_info: SamplingBatchInfo = None

@@ -231,6 +234,7 @@ def init_new(
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
input_embeds=batch.input_embeds,
)

if ret.global_num_tokens is not None:
Expand Down
14 changes: 11 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,17 @@ def forward_decode(self, forward_batch: ForwardBatch):
def forward_extend(self, forward_batch: ForwardBatch):
self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
if forward_batch.input_embeds is None:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
input_embeds=forward_batch.input_embeds.bfloat16(),
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -14,6 +14,7 @@
"test_double_sparsity.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_input_embeddings.py",
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_metrics.py",
114 changes: 114 additions & 0 deletions test/srt/test_input_embeddings.py
@@ -0,0 +1,114 @@
import json
import unittest

import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestInputEmbeds(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--disable-radix"],
)
cls.texts = [
"The capital of France is",
"What is the best time of year to visit Japan for cherry blossoms?",
]

def generate_input_embeddings(self, text):
"""Generate input embeddings for a given text."""
input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
embeddings = self.ref_model.get_input_embeddings()(input_ids)
return embeddings.squeeze().tolist() # Convert tensor to a list for API use

def send_request(self, payload):
"""Send a POST request to the API and return the response."""
response = requests.post(
self.base_url + "/generate",
json=payload,
timeout=30, # Set a reasonable timeout for the API request
)
if response.status_code == 200:
return response.json()
return {
"error": f"Request failed with status {response.status_code}: {response.text}"
}

def test_text_based_response(self):
"""Print API response using text-based input."""
for text in self.texts:
payload = {
"model": self.model,
"text": text,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
response = self.send_request(payload)
print(
f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
)

def test_embedding_based_response(self):
"""Print API response using input embeddings."""
for text in self.texts:
embeddings = self.generate_input_embeddings(text)
payload = {
"model": self.model,
"input_embeds": embeddings,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
response = self.send_request(payload)
print(
f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
)

def test_compare_text_vs_embedding(self):
"""Print responses for both text-based and embedding-based inputs."""
for text in self.texts:
# Text-based payload
text_payload = {
"model": self.model,
"text": text,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
# Embedding-based payload
embeddings = self.generate_input_embeddings(text)
embed_payload = {
"model": self.model,
"input_embeds": embeddings,
"sampling_params": {"temperature": 0, "max_new_tokens": 50},
}
# Get responses
text_response = self.send_request(text_payload)
embed_response = self.send_request(embed_payload)
# Print responses
print(
f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
)
print(
f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
)
self.assertEqual(text_response["text"], embed_response["text"])

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True)


if __name__ == "__main__":
unittest.main()