diff --git a/intel_extension_for_transformers/neural_chat/server/restful/textchat_api.py b/intel_extension_for_transformers/neural_chat/server/restful/textchat_api.py
index 269a8af82d1..4ebac846c9d 100644
--- a/intel_extension_for_transformers/neural_chat/server/restful/textchat_api.py
+++ b/intel_extension_for_transformers/neural_chat/server/restful/textchat_api.py
@@ -18,6 +18,8 @@
 from http import HTTPStatus
 import shortuuid
 import asyncio
+import requests
+from concurrent import futures
 from fastapi.routing import APIRouter
 from fastapi.responses import JSONResponse, StreamingResponse, Response
 from ...models.base_model import BaseModel
@@ -481,142 +483,228 @@ async def create_chat_completion(request: ChatCompletionRequest):
     See https://platform.openai.com/docs/api-reference/chat/create
     for the API specification.
     """
-    error_check_ret = await check_model(request)
-    if error_check_ret is not None:
-        return error_check_ret
-    error_check_ret = check_requests(request)
-    if error_check_ret is not None:
-        return error_check_ret
-
-    chatbot = router.get_chatbot()
-
-    if not is_openai_model(chatbot.model_name.lower()):
-        gen_params = await get_generation_parameters(
-            request.model,
-            chatbot,
-            request.messages,
-            temperature=request.temperature,
-            top_p=request.top_p,
-            top_k=request.top_k,
-            repetition_penalty=request.repetition_penalty,
-            presence_penalty=request.presence_penalty,
-            frequency_penalty=request.frequency_penalty,
-            max_tokens=request.max_tokens if request.max_tokens else 512,
-            echo=False,
-            stop=request.stop,
-        )
-    else:
-        gen_params = {
-            "prompt": request.messages,
-            "temperature": request.temperature,
-            "top_p": request.top_p,
-            "repetition_penalty": request.repetition_penalty,
-            "max_new_tokens": request.max_tokens,
-        }
-
-    if request.stream:
-        generator = chat_completion_stream_generator(
-            request.model, gen_params, request.n, chatbot
-        )
-        return StreamingResponse(generator, media_type="text/event-stream")
-
-    choices = []
-    chat_completions = []
-    for i in range(request.n):
-        content = asyncio.create_task(generate_completion(gen_params, chatbot))
-        chat_completions.append(content)
-    try:
-        all_tasks = await asyncio.gather(*chat_completions)
-    except Exception as e:
-        return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
-    usage = UsageInfo()
-    for i, content in enumerate(all_tasks):
-        if isinstance(content, str):
-            content = json.loads(content)
-
-        content_string = content["text"]
-        if content["error_code"] != 0:
-            return create_error_response(content["error_code"], content_string)
-        choices.append(
-            ChatCompletionResponseChoice(
-                index=i,
-                message=ChatMessage(role="assistant", content=content_string),
-                finish_reason=content.get("finish_reason", "stop"),
-            )
-        )
-        if "usage" in content:
-            task_usage = UsageInfo.parse_obj(content["usage"])
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+    if router.use_deepspeed:
+        if request.stream:
+            responses = []
+            def generate_stream(port):
+                url = f'http://{router.host}:{port}/v1/chat/completions'
+                response = requests.post(url, json=request.dict(), stream=True, timeout=1000)
+                responses.append(response)
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(generate_stream, worker_ports)
+
+            while not responses:
+                pass
+            def generate():
+                if responses[0]:
+                    for chunk in responses[0].iter_lines(decode_unicode=False, delimiter=b"\0"):
+                        if chunk:
+                            yield f"data: {chunk}\n\n"
+                    yield f"data: [DONE]\n\n"
 
-    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+            return StreamingResponse(generate(), media_type="text/event-stream")
+        else:
+            responses = []
+
+            def send_request(port):
+                try:
+                    url = f'http://{router.host}:{port}/v1/chat/completions'
+                    response = requests.post(url, json=request.dict())
+                    response.raise_for_status()
+                    json_response = json.loads(response.content)
+                    chat_completion_response = ChatCompletionResponse(response=json_response['response'])
+                    responses.append(chat_completion_response)
+                except requests.exceptions.RequestException as e:
+                    print(f"Error sending/receiving on port {port}: {e}")
+
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(send_request, worker_ports)
+            if responses:
+                return responses[0]
 
-@router.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
-    error_check_ret = await check_model(request)
-    if error_check_ret is not None:
-        return error_check_ret
-    error_check_ret = check_requests(request)
-    if error_check_ret is not None:
-        return error_check_ret
-
-    chatbot = router.get_chatbot()
-    request.prompt = process_input(request.model, request.prompt)
-
-    if request.stream:
-        generator = generate_completion_stream_generator(
-            request, request.n, chatbot
-        )
-        return StreamingResponse(generator, media_type="text/event-stream")
     else:
-        text_completions = []
-        for text in request.prompt:
+        error_check_ret = await check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+        error_check_ret = check_requests(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        chatbot = router.get_chatbot()
+
+        if not is_openai_model(chatbot.model_name.lower()):
             gen_params = await get_generation_parameters(
                 request.model,
                 chatbot,
-                text,
+                request.messages,
                 temperature=request.temperature,
                 top_p=request.top_p,
                 top_k=request.top_k,
                 repetition_penalty=request.repetition_penalty,
-                frequency_penalty=request.frequency_penalty,
                 presence_penalty=request.presence_penalty,
-                max_tokens=request.max_tokens,
-                logprobs=request.logprobs,
-                echo=request.echo,
+                frequency_penalty=request.frequency_penalty,
+                max_tokens=request.max_tokens if request.max_tokens else 512,
+                echo=False,
                 stop=request.stop,
-                best_of=request.best_of,
-                use_beam_search=request.use_beam_search,
             )
-            for i in range(request.n):
-                content = asyncio.create_task(
-                    generate_completion(gen_params, chatbot)
-                )
-                text_completions.append(content)
+        else:
+            gen_params = {
+                "prompt": request.messages,
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "repetition_penalty": request.repetition_penalty,
+                "max_new_tokens": request.max_tokens,
+            }
+
+        if request.stream:
+            generator = chat_completion_stream_generator(
+                request.model, gen_params, request.n, chatbot
+            )
+            return StreamingResponse(generator, media_type="text/event-stream")
+        choices = []
+        chat_completions = []
+        for i in range(request.n):
+            content = asyncio.create_task(generate_completion(gen_params, chatbot))
+            chat_completions.append(content)
         try:
-            all_tasks = await asyncio.gather(*text_completions)
+            all_tasks = await asyncio.gather(*chat_completions)
         except Exception as e:
             return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
-
-        choices = []
         usage = UsageInfo()
         for i, content in enumerate(all_tasks):
+            if isinstance(content, str):
+                content = json.loads(content)
+
+            content_string = content["text"]
             if content["error_code"] != 0:
-                return create_error_response(content["error_code"], content["text"])
+                return create_error_response(content["error_code"], content_string)
             choices.append(
-                CompletionResponseChoice(
+                ChatCompletionResponseChoice(
                     index=i,
-                    text=content["text"],
-                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
+                    message=ChatMessage(role="assistant", content=content_string),
                     finish_reason=content.get("finish_reason", "stop"),
                 )
             )
-            task_usage = UsageInfo.parse_obj(content["usage"])
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+            if "usage" in content:
+                task_usage = UsageInfo.parse_obj(content["usage"])
+                for usage_key, usage_value in task_usage.dict().items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
 
-        return CompletionResponse(
-            model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
-        )
+        return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
+
+
+@router.post("/v1/completions")
+async def create_completion(request: CompletionRequest):
+    if router.use_deepspeed:
+        if request.stream:
+            responses = []
+            def generate_stream(port):
+                url = f'http://{router.host}:{port}/v1/completions'
+                response = requests.post(url, json=request.dict(), stream=True, timeout=1000)
+                responses.append(response)
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(generate_stream, worker_ports)
+
+            while not responses:
+                pass
+            def generate():
+                if responses[0]:
+                    for chunk in responses[0].iter_lines(decode_unicode=False, delimiter=b"\0"):
+                        if chunk:
+                            yield f"data: {chunk}\n\n"
+                    yield f"data: [DONE]\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+
+        else:
+            responses = []
+
+            def send_request(port):
+                try:
+                    url = f'http://{router.host}:{port}/v1/completions'
+                    response = requests.post(url, json=request.dict())
+                    response.raise_for_status()
+                    json_response = json.loads(response.content)
+                    chat_completion_response = ChatCompletionResponse(response=json_response['response'])
+                    responses.append(chat_completion_response)
+                except requests.exceptions.RequestException as e:
+                    print(f"Error sending/receiving on port {port}: {e}")
+
+            with futures.ThreadPoolExecutor(max_workers=router.world_size) as executor:
+                worker_ports = [router.port + i + 1 for i in range(router.world_size)]
+                executor.map(send_request, worker_ports)
+            if responses:
+                return responses[0]
+
+    else:
+        error_check_ret = await check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+        error_check_ret = check_requests(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        chatbot = router.get_chatbot()
+        request.prompt = process_input(request.model, request.prompt)
+
+        if request.stream:
+            generator = generate_completion_stream_generator(
+                request, request.n, chatbot
+            )
+            return StreamingResponse(generator, media_type="text/event-stream")
+        else:
+            text_completions = []
+            for text in request.prompt:
+                gen_params = await get_generation_parameters(
+                    request.model,
+                    chatbot,
+                    text,
+                    temperature=request.temperature,
+                    top_p=request.top_p,
+                    top_k=request.top_k,
+                    repetition_penalty=request.repetition_penalty,
+                    frequency_penalty=request.frequency_penalty,
+                    presence_penalty=request.presence_penalty,
+                    max_tokens=request.max_tokens,
+                    logprobs=request.logprobs,
+                    echo=request.echo,
+                    stop=request.stop,
+                    best_of=request.best_of,
+                    use_beam_search=request.use_beam_search,
+                )
+                for i in range(request.n):
+                    content = asyncio.create_task(
+                        generate_completion(gen_params, chatbot)
+                    )
+                    text_completions.append(content)
+
+            try:
+                all_tasks = await asyncio.gather(*text_completions)
+            except Exception as e:
+                return create_error_response(ApiErrorCode.INTERNAL_ERROR, str(e))
+
+            choices = []
+            usage = UsageInfo()
+            for i, content in enumerate(all_tasks):
+                if content["error_code"] != 0:
+                    return create_error_response(content["error_code"], content["text"])
+                choices.append(
+                    CompletionResponseChoice(
+                        index=i,
+                        text=content["text"],
+                        logprobs=create_openai_logprobs(content.get("logprobs", None)),
+                        finish_reason=content.get("finish_reason", "stop"),
+                    )
+                )
+                task_usage = UsageInfo.parse_obj(content["usage"])
+                for usage_key, usage_value in task_usage.dict().items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+            return CompletionResponse(
+                model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
+            )
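For reference, a minimal client-side sketch of how the streaming endpoints touched by this diff can be consumed. The host, port, and model name are assumptions for illustration; the chunk framing (`data: <chunk>` lines ending with `data: [DONE]`) mirrors the `generate()` helpers added above.

```python
# Hypothetical client for the streaming chat endpoint in this diff.
# Assumptions: the router listens on localhost:8000 (not specified in the
# diff) and the model name is illustrative.
import requests

payload = {
    "model": "Intel/neural-chat-7b-v3-1",  # assumed model name
    "messages": [{"role": "user", "content": "Tell me about Intel Xeon."}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True, timeout=1000) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # end-of-stream sentinel emitted by generate()
            break
        print(data, flush=True)
```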
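The non-streaming DeepSpeed branches above depend on `ThreadPoolExecutor.map` submitting every call eagerly and on the `with` block joining all worker threads on exit, which is why `responses[0]` is safe to read immediately afterwards. A distilled sketch of that fan-out pattern follows; the `fan_out` helper name and its arguments are hypothetical, while the `base_port + i + 1` port scheme mirrors the diff.

```python
# Sketch of the worker fan-out used above: POST the same request to every
# DeepSpeed worker (ports base+1 .. base+world_size) and return the first
# collected response. Names and arguments here are illustrative.
import requests
from concurrent import futures

def fan_out(host: str, base_port: int, world_size: int, path: str, payload: dict):
    responses = []

    def send_request(port: int):
        try:
            resp = requests.post(f"http://{host}:{port}{path}", json=payload)
            resp.raise_for_status()
            responses.append(resp.json())
        except requests.exceptions.RequestException as e:
            print(f"Error sending/receiving on port {port}: {e}")

    worker_ports = [base_port + i + 1 for i in range(world_size)]
    # Executor.map submits every call eagerly; leaving the with-block joins
    # all threads, so `responses` is fully populated afterwards.
    with futures.ThreadPoolExecutor(max_workers=world_size) as executor:
        list(executor.map(send_request, worker_ports))
    return responses[0] if responses else None
```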