Merge pull request #1654 from h2oai/fix_sglang_asyncio
Fix asyncio sglang use
pseudotensor authored May 30, 2024
2 parents 841ae0b + 4d1bf1a commit c7f8b8f
Showing 2 changed files with 62 additions and 30 deletions.
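
This commit splits the multi-image fan-out out of setup_call into separate sync and async entry points (do_many / a_do_many), so the async _acall path can await the requests directly rather than relying on asyncio.run inside setup_call, which fails when an event loop is already running. A minimal standalone sketch of that pattern follows (hypothetical helper names, not code from this repository):

import asyncio


async def fake_request(url):
    # stand-in for the aiohttp POST that send_request performs
    await asyncio.sleep(0.01)
    return {"text": f"answer from {url}"}


async def fan_out(urls):
    # stand-in for get_many: one request per image, gathered concurrently
    return await asyncio.gather(*(fake_request(u) for u in urls))


def do_many(urls):
    # sync entry point: safe only when no event loop is already running
    return asyncio.run(fan_out(urls))


async def a_do_many(urls):
    # async entry point: awaited from an async caller such as _acall
    return await fan_out(urls)
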
90 changes: 61 additions & 29 deletions src/gpt_langchain.py
@@ -80,7 +80,7 @@
does_support_json_mode, claude3imagetag, gpt4imagetag, geminiimagetag, \
geminiimage_num_max, claude3image_num_max, gpt4image_num_max, llava_num_max, summary_prefix, extract_prefix, \
noop_prompt_type, unknown_prompt_type, template_prompt_type, none, claude3_image_tokens, gemini_image_tokens, \
gpt4_image_tokens, user_prompt_for_fake_system_prompt0
gpt4_image_tokens, user_prompt_for_fake_system_prompt0, empty_prompt_type
from evaluate_params import gen_hyper, gen_hyper0
from gen import SEED, get_limited_prompt, get_relaxed_max_new_tokens, get_model_retry, gradio_to_llm, \
get_client_from_inference_server
@@ -1520,6 +1520,13 @@ class SGlangInference(AGenerateStreamFirst, H2Oagenerate, LLM):
prompts: Any = []
count_output_tokens: Any = 0

# runtime
assistant_role: Any = None
user_role: Any = None
conv_template_before_prompt: Any = None
url: Any = None
pload: Any = None

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that python package exists in environment."""
@@ -1550,16 +1557,24 @@ def get_conv_template(conv_template_name):
conv_template = copy.deepcopy(getattr(conversation_module, conv_template_name))
return conv_template

async def send_request(self, url, data, delay=0):
async def send_request(self, url, data, delay=0, timeout=None):
if timeout is None:
timeout = self.max_time
await asyncio.sleep(delay)
timeout_settings = aiohttp.ClientTimeout(total=timeout) # Set the total timeout
async_sem = AsyncNullContext() if self.async_sem is None else self.async_sem
async with async_sem: # semaphore limits num of simultaneous downloads
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(timeout=timeout_settings) as session:
async with session.post(url, json=data) as resp:
print("headers: %s" % resp.headers, flush=True)
if resp.headers['Content-Type'] == 'application/json':
output = await resp.json()
else:
output = await resp.text()
output_text = await resp.text()
output = {"text": output_text}
if resp.status == 504:
print(f"504 Response received from {url}: {output}", flush=True)
raise TimeoutError(output_text)
return output

def setup_call(self, prompt):
@@ -1573,15 +1588,15 @@ def setup_call(self, prompt):

conv_template_name = self.inference_server.split(':')[1]
conv_template = self.get_conv_template(conv_template_name)
user_role = conv_template.roles[0]
assistant_role = conv_template.roles[1]
self.user_role = conv_template.roles[0]
self.assistant_role = conv_template.roles[1]
if self.system_prompt:
if not conv_template.system:
# assume means can't handle if didn't exist in template
conv_template.append_message(role=user_role, message=self.user_prompt_for_fake_system_prompt)
conv_template.append_message(role=self.user_role, message=self.user_prompt_for_fake_system_prompt)
if self.system_prompt == 'auto':
self.system_prompt = 'You are a helpful assistant.' if not self.image_file else "You are helpful visual LLM assistant capable of understanding text and images."
conv_template.append_message(role=assistant_role, message=self.system_prompt)
conv_template.append_message(role=self.assistant_role, message=self.system_prompt)
else:
our_system_prompt = False
if our_system_prompt:
@@ -1594,15 +1609,15 @@ def setup_call(self, prompt):
conv_template.append_message(role="system", message=self.system_prompt)
for message in self.chat_conversation:
if isinstance(message[0], str) and message[0]:
conv_template.append_message(role=user_role, message=message[0])
conv_template.append_message(role=self.user_role, message=message[0])
if isinstance(message[1], str) and message[1]:
conv_template.append_message(role=assistant_role, message=message[1])
conv_template.append_message(role=self.assistant_role, message=message[1])

conv_template_before_prompt = copy.deepcopy(conv_template)
self.conv_template_before_prompt = copy.deepcopy(conv_template)

prompt_with_image = f"<image>\n{prompt}"
conv_template.append_message(role=user_role, message=prompt_with_image)
conv_template.append_message(role=assistant_role, message=None)
conv_template.append_message(role=self.user_role, message=prompt_with_image)
conv_template.append_message(role=self.assistant_role, message=None)
prompt_with_template = conv_template.get_prompt()
if self.context:
prompt_with_template = self.context + prompt_with_template
@@ -1611,7 +1626,7 @@ def setup_call(self, prompt):
presence_penalty = (self.repetition_penalty - 1.0) * 2.0 + 0.0 # so good default

terminate_response = update_terminate_responses([], tokenizer=self.tokenizer)
pload = {
self.pload = {
"text": prompt_with_template,
"sampling_params": {
"max_new_tokens": self.max_new_tokens,
@@ -1624,31 +1639,39 @@ def setup_call(self, prompt):
"image_data": self.image_file[0],
"stream": self.stream_output,
}
url = self.inference_server_url + "/generate"
self.url = self.inference_server_url + "/generate"

if len(self.image_file) > 1:
# deal with all images
# also contains prompt_tokens, completion_tokens, finish_reason, etc.
responses = asyncio.run(self.get_many(url, pload))
def do_many(self):
# deal with all images
# also contains prompt_tokens, completion_tokens, finish_reason, etc.
return asyncio.run(self.get_many(self.url, self.pload))

async def a_do_many(self):
return await self.get_many(self.url, self.pload)

def many_to_prompt(self, prompt, responses):
if len(self.image_file) > 1:
# now use all those in final prompt
responses_context = '\n\n'.join(['# Image %d Answer\n\n%s\n\n' % (i, r['text']) for i, r in
enumerate(responses)])
prompt_with_responses = f"{responses_context}\n{prompt}"
conv_template_before_prompt.append_message(role=user_role, message=prompt_with_responses)
conv_template.append_message(role=assistant_role, message=None)
prompt_with_template = conv_template_before_prompt.get_prompt()
self.conv_template_before_prompt.append_message(role=self.user_role, message=prompt_with_responses)
self.conv_template_before_prompt.append_message(role=self.assistant_role, message=None)
prompt_with_template = self.conv_template_before_prompt.get_prompt()
if self.context:
prompt_with_template = self.context + prompt_with_template
self.prompts.append(prompt_with_template)
pload.pop('image_data') # no longer have images

response = requests.post(
url,
json=pload,
# update pload
self.pload['text'] = prompt_with_template # prompt now has response per image as single prompt
self.pload.pop('image_data') # no longer have images, just text

def do_final(self):
return requests.post(
self.url,
json=self.pload,
stream=self.stream_output,
)
return response

async def get_many(self, url, pload):
pload_no_image = pload.copy()
@@ -1673,7 +1696,11 @@ def _call(
if self.verbose:
print("_call", flush=True)

response = self.setup_call(prompt)
self.setup_call(prompt)
if len(self.image_file) > 1:
responses = self.do_many()
self.many_to_prompt(prompt, responses)
response = self.do_final()

if not self.stream_output:
response = response.json()['text']
@@ -1731,7 +1758,11 @@ async def _acall(
if self.verbose:
print("_call", flush=True)

response = self.setup_call(prompt)
self.setup_call(prompt)
if len(self.image_file) > 1:
responses = await self.a_do_many()
self.many_to_prompt(prompt, responses)
response = self.do_final()

if not self.stream_output:
response = response.json()['text']
@@ -3252,6 +3283,7 @@ def get_llm(use_openai_model=False,
callbacks = [streaming_callback]
streamer = callbacks[0] if stream_output else None

num_async = min(2, num_async) # can't handle as much
async_sem = asyncio.Semaphore(num_async) if async_output else AsyncNullContext()

if is_vision_model(model_name):
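
The updated send_request above bounds each call with an aiohttp.ClientTimeout, gates concurrency with a semaphore (capped at 2 in get_llm), and surfaces a 504 gateway timeout as TimeoutError. A self-contained sketch of that request pattern, with a hypothetical endpoint and payload (not code from this repository):

import asyncio

import aiohttp


async def post_with_timeout(url, payload, sem, timeout_s=60):
    # total timeout for the whole request, as aiohttp.ClientTimeout(total=...) provides
    timeout_settings = aiohttp.ClientTimeout(total=timeout_s)
    async with sem:  # semaphore limits the number of simultaneous requests
        async with aiohttp.ClientSession(timeout=timeout_settings) as session:
            async with session.post(url, json=payload) as resp:
                if resp.headers.get('Content-Type') == 'application/json':
                    output = await resp.json()
                else:
                    output = {"text": await resp.text()}
                if resp.status == 504:
                    # gateway timeout: raise instead of returning partial output
                    raise TimeoutError(str(output))
                return output


async def main():
    sem = asyncio.Semaphore(2)  # mirrors num_async = min(2, num_async) above
    url = "http://localhost:30000/generate"  # hypothetical sglang server endpoint
    payload = {"text": "hello", "sampling_params": {"max_new_tokens": 16}}
    print(await post_with_timeout(url, payload, sem))


if __name__ == "__main__":
    asyncio.run(main())
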
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "0249c1dce0a4195d88bd449615c1867d7b627529"
__version__ = "74bfb74acd5bf9bb93da9cc5d6fa1df6eb569657"
