This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[NeuralChat] Support deepspeed for textchat API (#1443)
letonghan authored Apr 2, 2024 · 1 parent 1065d81 · commit 7b0b995
Showing 1 changed file with 200 additions and 112 deletions.
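In brief: this commit teaches both text-chat endpoints about DeepSpeed tensor-parallel serving. When the router is started with `router.use_deepspeed`, each incoming request is broadcast over HTTP to every worker process (rank i listens on `router.port + i + 1`) via a thread pool, and the first worker response collected is returned to the caller; without DeepSpeed, the original in-process generation path runs unchanged. Two hedged sketches after the diff isolate the fan-out pattern and the client-side handling of the stream.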
@@ -18,6 +18,8 @@
 from http import HTTPStatus
 import shortuuid
 import asyncio
+import requests
+from concurrent import futures
 from fastapi.routing import APIRouter
 from fastapi.responses import JSONResponse, StreamingResponse, Response
 from ...models.base_model import BaseModel
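The two added imports serve the new code path below: `requests` issues the HTTP calls to the DeepSpeed workers, and `concurrent.futures` provides the thread pool that fans those calls out in parallel.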
@@ -481,142 +483,228 @@ async def create_chat_completion(request: ChatCompletionRequest):
     See https://platform.openai.com/docs/api-reference/chat/create for the API specification.
     """
-    error_check_ret = await check_model(request)
-    if error_check_ret is not None:
-        return error_check_ret
-    error_check_ret = check_requests(request)
-    if error_check_ret is not None:
-        return error_check_ret
-
-    chatbot = router.get_chatbot()
-
-    if not is_openai_model(chatbot.model_name.lower()):
-        gen_params = await get_generation_parameters(
-            request.model,
-            chatbot,
-            request.messages,
-            temperature=request.temperature,
-            top_p=request.top_p,
-            top_k=request.top_k,
-            repetition_penalty=request.repetition_penalty,
-            presence_penalty=request.presence_penalty,
-            frequency_penalty=request.frequency_penalty,
-            max_tokens=request.max_tokens if request.max_tokens else 512,
-            echo=False,
-            stop=request.stop,
-        )
-    else:
-        gen_params = {
-            "prompt": request.messages,
-            "temperature": request.temperature,
-            "top_p": request.top_p,
-            "repetition_penalty": request.repetition_penalty,
-            "max_new_tokens": request.max_tokens,
-        }
-
-    if request.stream:
-        generator = chat_completion_stream_generator(
-            request.model, gen_params, request.n, chatbot
-        )
-        return StreamingResponse(generator, media_type="text/event-stream")
-
-    choices = []
-    chat_completions = []
-    for i in range(request.n):
-        content = asyncio.create_task(generate_completion(gen_params, chatbot))
-        chat_completions.append(content)
-    try:
-        all_tasks = await asyncio.gather(*chat_completions)
-    except Exception as e:
-        return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
-    usage = UsageInfo()
-    for i, content in enumerate(all_tasks):
-        if isinstance(content, str):
-            content = json.loads(content)
-
-        content_string = content["text"]
-        if content["error_code"] != 0:
-            return create_error_response(content["error_code"], content_string)
-        choices.append(
-            ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role="assistant", content=content_string),
-                finish_reason=content.get("finish_reason", "stop"),
-            )
-        )
-        if "usage" in content:
-            task_usage = UsageInfo.parse_obj(content["usage"])
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-
-    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+    if router.use_deepspeed:
+        if request.stream:
+            responses = []
+            def generate_stream(port):
+                url = f'http://{router.host}:{port}/v1/chat/completions'
+                response = requests.post(url, json=request.dict(), stream=True, timeout=1000)
+                responses.append(response)
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(generate_stream, worker_ports)
+
+            while not responses:
+                pass
+            def generate():
+                if responses[0]:
+                    for chunk in responses[0].iter_lines(decode_unicode=False, delimiter=b"\0"):
+                        if chunk:
+                            yield f"data: {chunk}\n\n"
+                yield f"data: [DONE]\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")

+        else:
+            responses = []
+
+            def send_request(port):
+                try:
+                    url = f'http://{router.host}:{port}/v1/chat/completions'
+                    response = requests.post(url, json=request.dict())
+                    response.raise_for_status()
+                    json_response = json.loads(response.content)
+                    chat_completion_response = ChatCompletionResponse(response=json_response['response'])
+                    responses.append(chat_completion_response)
+                except requests.exceptions.RequestException as e:
+                    print(f"Error sending/receiving on port {port}: {e}")
+
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(send_request, worker_ports)
+            if responses:
+                return responses[0]

+    else:
+        error_check_ret = await check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+        error_check_ret = check_requests(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        chatbot = router.get_chatbot()
+
+        if not is_openai_model(chatbot.model_name.lower()):
+            gen_params = await get_generation_parameters(
+                request.model,
+                chatbot,
+                request.messages,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                top_k=request.top_k,
+                repetition_penalty=request.repetition_penalty,
+                presence_penalty=request.presence_penalty,
+                frequency_penalty=request.frequency_penalty,
+                max_tokens=request.max_tokens if request.max_tokens else 512,
+                echo=False,
+                stop=request.stop,
+            )
+        else:
+            gen_params = {
+                "prompt": request.messages,
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "repetition_penalty": request.repetition_penalty,
+                "max_new_tokens": request.max_tokens,
+            }
+
+        if request.stream:
+            generator = chat_completion_stream_generator(
+                request.model, gen_params, request.n, chatbot
+            )
+            return StreamingResponse(generator, media_type="text/event-stream")
+
+        choices = []
+        chat_completions = []
+        for i in range(request.n):
+            content = asyncio.create_task(generate_completion(gen_params, chatbot))
+            chat_completions.append(content)
+        try:
+            all_tasks = await asyncio.gather(*chat_completions)
+        except Exception as e:
+            return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
+        usage = UsageInfo()
+        for i, content in enumerate(all_tasks):
+            if isinstance(content, str):
+                content = json.loads(content)
+
+            content_string = content["text"]
+            if content["error_code"] != 0:
+                return create_error_response(content["error_code"], content_string)
+            choices.append(
+                ChatCompletionResponseChoice(
+                    index=i,
+                    message=ChatMessage(role="assistant", content=content_string),
+                    finish_reason=content.get("finish_reason", "stop"),
+                )
+            )
+            if "usage" in content:
+                task_usage = UsageInfo.parse_obj(content["usage"])
+                for usage_key, usage_value in task_usage.dict().items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+        return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)


 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest):
-    error_check_ret = await check_model(request)
-    if error_check_ret is not None:
-        return error_check_ret
-    error_check_ret = check_requests(request)
-    if error_check_ret is not None:
-        return error_check_ret
-
-    chatbot = router.get_chatbot()
-    request.prompt = process_input(request.model, request.prompt)
-
-    if request.stream:
-        generator = generate_completion_stream_generator(
-            request, request.n, chatbot
-        )
-        return StreamingResponse(generator, media_type="text/event-stream")
-    else:
-        text_completions = []
-        for text in request.prompt:
-            gen_params = await get_generation_parameters(
-                request.model,
-                chatbot,
-                text,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                top_k=request.top_k,
-                repetition_penalty=request.repetition_penalty,
-                frequency_penalty=request.frequency_penalty,
-                presence_penalty=request.presence_penalty,
-                max_tokens=request.max_tokens,
-                logprobs=request.logprobs,
-                echo=request.echo,
-                stop=request.stop,
-                best_of=request.best_of,
-                use_beam_search=request.use_beam_search,
-            )
-            for i in range(request.n):
-                content = asyncio.create_task(
-                    generate_completion(gen_params, chatbot)
-                )
-                text_completions.append(content)
-
-        try:
-            all_tasks = await asyncio.gather(*text_completions)
-        except Exception as e:
-            return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
-
-        choices = []
-        usage = UsageInfo()
-        for i, content in enumerate(all_tasks):
-            if content["error_code"] != 0:
-                return create_error_response(content["error_code"], content["text"])
-            choices.append(
-                CompletionResponseChoice(
-                    index=i,
-                    text=content["text"],
-                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
-                    finish_reason=content.get("finish_reason", "stop"),
-                )
-            )
-            task_usage = UsageInfo.parse_obj(content["usage"])
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-
-        return CompletionResponse(
-            model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
-        )


+    if router.use_deepspeed:
+        if request.stream:
+            responses = []
+            def generate_stream(port):
+                url = f'http://{router.host}:{port}/v1/completions'
+                response = requests.post(url, json=request.dict(), stream=True, timeout=1000)
+                responses.append(response)
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(generate_stream, worker_ports)
+
+            while not responses:
+                pass
+            def generate():
+                if responses[0]:
+                    for chunk in responses[0].iter_lines(decode_unicode=False, delimiter=b"\0"):
+                        if chunk:
+                            yield f"data: {chunk}\n\n"
+                yield f"data: [DONE]\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+
+        else:
+            responses = []
+
+            def send_request(port):
+                try:
+                    url = f'http://{router.host}:{port}/v1/completions'
+                    response = requests.post(url, json=request.dict())
+                    response.raise_for_status()
+                    json_response = json.loads(response.content)
+                    chat_completion_response = ChatCompletionResponse(response=json_response['response'])
+                    responses.append(chat_completion_response)
+                except requests.exceptions.RequestException as e:
+                    print(f"Error sending/receiving on port {port}: {e}")
+
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(send_request, worker_ports)
+            if responses:
+                return responses[0]
+
+    else:
+        error_check_ret = await check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+        error_check_ret = check_requests(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        chatbot = router.get_chatbot()
+        request.prompt = process_input(request.model, request.prompt)
+
+        if request.stream:
+            generator = generate_completion_stream_generator(
+                request, request.n, chatbot
+            )
+            return StreamingResponse(generator, media_type="text/event-stream")
+        else:
+            text_completions = []
+            for text in request.prompt:
+                gen_params = await get_generation_parameters(
+                    request.model,
+                    chatbot,
+                    text,
+                    temperature=request.temperature,
+                    top_p=request.top_p,
+                    top_k=request.top_k,
+                    repetition_penalty=request.repetition_penalty,
+                    frequency_penalty=request.frequency_penalty,
+                    presence_penalty=request.presence_penalty,
+                    max_tokens=request.max_tokens,
+                    logprobs=request.logprobs,
+                    echo=request.echo,
+                    stop=request.stop,
+                    best_of=request.best_of,
+                    use_beam_search=request.use_beam_search,
+                )
+                for i in range(request.n):
+                    content = asyncio.create_task(
+                        generate_completion(gen_params, chatbot)
+                    )
+                    text_completions.append(content)
+
+            try:
+                all_tasks = await asyncio.gather(*text_completions)
+            except Exception as e:
+                return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
+
+            choices = []
+            usage = UsageInfo()
+            for i, content in enumerate(all_tasks):
+                if content["error_code"] != 0:
+                    return create_error_response(content["error_code"], content["text"])
+                choices.append(
+                    CompletionResponseChoice(
+                        index=i,
+                        text=content["text"],
+                        logprobs=create_openai_logprobs(content.get("logprobs", None)),
+                        finish_reason=content.get("finish_reason", "stop"),
+                    )
+                )
+                task_usage = UsageInfo.parse_obj(content["usage"])
+                for usage_key, usage_value in task_usage.dict().items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+            return CompletionResponse(
+                model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
+            )
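To make the new control flow easy to try in isolation, here is a minimal, self-contained sketch of the non-streaming fan-out pattern the diff adds. It is an illustration under stated assumptions, not the NeuralChat implementation: `HOST`, `BASE_PORT`, `WORLD_SIZE`, the payload, and the model name are placeholders, while the port arithmetic, the thread-pool dispatch, and the return-the-first-response behavior mirror the diff.

# Hypothetical standalone sketch of the fan-out added above: POST one payload
# to every DeepSpeed worker (rank i listens on BASE_PORT + i + 1) and return
# the first response collected, as the diff's send_request/responses[0] does.
import json
from concurrent import futures
from typing import Optional

import requests

HOST = "127.0.0.1"   # assumption: router host
BASE_PORT = 8000     # assumption: router port; workers then sit on 8001, 8002, ...
WORLD_SIZE = 2       # assumption: number of DeepSpeed ranks


def broadcast(payload: dict) -> Optional[dict]:
    """Send the same request to all workers and keep whatever comes back."""
    responses = []

    def send_request(port: int) -> None:
        try:
            url = f"http://{HOST}:{port}/v1/chat/completions"
            response = requests.post(url, json=payload, timeout=60)
            response.raise_for_status()
            responses.append(json.loads(response.content))
        except requests.exceptions.RequestException as e:
            print(f"Error sending/receiving on port {port}: {e}")

    with futures.ThreadPoolExecutor(max_workers=WORLD_SIZE) as executor:
        worker_ports = [BASE_PORT + i + 1 for i in range(WORLD_SIZE)]
        # Exiting the `with` block waits for every submitted call, so all
        # workers have been contacted before we read `responses`.
        list(executor.map(send_request, worker_ports))

    # Tensor-parallel ranks decode the same tokens, so the diff treats the
    # workers' answers as interchangeable and returns the first one.
    return responses[0] if responses else None


if __name__ == "__main__":
    print(broadcast({
        "model": "Intel/neural-chat-7b-v3-1",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
    }))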

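For the streaming branch, the `generate()` wrapper in the diff re-frames each worker chunk as a `data: ...` event and finishes with `data: [DONE]`, so a client can read it like server-sent events. A hedged consumer sketch, again with a placeholder URL and payload:

# Hypothetical client for the streaming path above: read "data: ..." events
# until the "[DONE]" sentinel. URL, model name, and prompt are placeholders.
import requests

url = "http://127.0.0.1:8000/v1/chat/completions"  # assumption: router address
payload = {
    "model": "Intel/neural-chat-7b-v3-1",          # placeholder model name
    "messages": [{"role": "user", "content": "Tell me about DeepSpeed."}],
    "stream": True,
}

with requests.post(url, json=payload, stream=True, timeout=1000) as response:
    response.raise_for_status()
    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue                 # blank separator between events
        if line == "data: [DONE]":
            break                    # end-of-stream sentinel from generate()
        if line.startswith("data: "):
            print(line[len("data: "):], flush=True)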