diff --git a/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact_webgui/webgui/selfhost_fastapi_completions.py
index 5c12e476..ec3f7869 100644
--- a/refact_webgui/webgui/selfhost_fastapi_completions.py
+++ b/refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -72,11 +72,19 @@ class NlpCompletion(NlpSamplingParams):
 class ChatMessage(BaseModel):
     role: str
     content: str
+    # TODO: validate using pydantic
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
 
 
 class ChatContext(NlpSamplingParams):
     model: str = Query(pattern="^[a-z/A-Z0-9_\.\-]+$")
     messages: List[ChatMessage]
+    # TODO: validate using pydantic
+    tools: Optional[List[Dict[str, Any]]] = None
+    tool_choice: Optional[str] = None
+    stream: bool = True
+    stop: Optional[Any]
     n: int = 1
 
 
@@ -479,70 +487,94 @@ async def _models(self, authorization: str = Header(None)):
         }
 
     async def _chat_completions(self, post: ChatContext, authorization: str = Header(None)):
-        account = await self._account_from_bearer(authorization)
-
+        _account = await self._account_from_bearer(authorization)
+        messages = [m.dict() for m in post.messages]
         prefix, postfix = "data: ", "\n\n"
         model_dict = self._model_assigner.models_db_with_passthrough.get(post.model, {})
 
-        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
-            log(f"chat/completions: model resolve {post.model} -> {model_name}")
-
-            async def litellm_streamer(post: ChatContext):
+        async def litellm_streamer():
+            try:
+                self._integrations_env_setup()
+                response = await litellm.acompletion(
+                    model=model_name, messages=messages, stream=True,
+                    temperature=post.temperature, top_p=post.top_p,
+                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
+                    tools=post.tools,
+                    tool_choice=post.tool_choice,
+                    stop=post.stop
+                )
+                finish_reason = None
+                async for model_response in response:
+                    try:
+                        data = model_response.dict()
+                        finish_reason = data["choices"][0]["finish_reason"]
+                    except json.JSONDecodeError:
+                        data = {"choices": [{"finish_reason": finish_reason}]}
+                    yield prefix + json.dumps(data) + postfix
+                # NOTE: DONE needed by refact-lsp server
+                yield prefix + "[DONE]" + postfix
+            except BaseException as e:
+                err_msg = f"litellm error: {e}"
+                log(err_msg)
+                yield prefix + json.dumps({"error": err_msg}) + postfix
+
+        async def litellm_non_streamer():
+            try:
+                self._integrations_env_setup()
+                model_response = await litellm.acompletion(
+                    model=model_name, messages=messages, stream=False,
+                    temperature=post.temperature, top_p=post.top_p,
+                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
+                    tools=post.tools,
+                    tool_choice=post.tool_choice,
+                    stop=post.stop
+                )
+                finish_reason = None
+                try:
+                    data = model_response.dict()
+                    finish_reason = data["choices"][0]["finish_reason"]
+                except json.JSONDecodeError:
+                    data = {"choices": [{"finish_reason": finish_reason}]}
+                yield json.dumps(data)
+            except BaseException as e:
+                err_msg = f"litellm error: {e}"
+                log(err_msg)
+                yield json.dumps({"error": err_msg})
+
+        async def chat_completion_streamer():
+            post_url = "http://127.0.0.1:8001/v1/chat"
+            payload = {
+                "messages": messages,
+                "stream": True,
+                "model": post.model,
+                "parameters": {
+                    "temperature": post.temperature,
+                    "max_new_tokens": post.max_tokens,
+                }
+            }
+            async with aiohttp.ClientSession() as session:
                 try:
-                    self._integrations_env_setup()
-                    response = await litellm.acompletion(
-                        model=model_name, messages=[m.dict() for m in post.messages], stream=True,
-                        temperature=post.temperature, top_p=post.top_p,
-                        max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
-                        stop=post.stop)
-                    finish_reason = None
-                    async for model_response in response:
-                        try:
-                            data = model_response.dict()
-                            finish_reason = data["choices"][0]["finish_reason"]
-                        except json.JSONDecodeError:
-                            data = {"choices": [{"finish_reason": finish_reason}]}
-                        yield prefix + json.dumps(data) + postfix
-                    # NOTE: DONE needed by refact-lsp server
-                    yield prefix + "[DONE]" + postfix
-                except BaseException as e:
-                    err_msg = f"litellm error: {e}"
+                    async with session.post(post_url, json=payload) as response:
+                        finish_reason = None
+                        async for data, _ in response.content.iter_chunks():
+                            try:
+                                data = data.decode("utf-8")
+                                data = json.loads(data[len(prefix):-len(postfix)])
+                                finish_reason = data["choices"][0]["finish_reason"]
+                                data["choices"][0]["finish_reason"] = None
+                            except json.JSONDecodeError:
+                                data = {"choices": [{"finish_reason": finish_reason}]}
+                            yield prefix + json.dumps(data) + postfix
+                except aiohttp.ClientConnectorError as e:
+                    err_msg = f"LSP server is not ready yet: {e}"
                     log(err_msg)
                     yield prefix + json.dumps({"error": err_msg}) + postfix
 
-            response_streamer = litellm_streamer(post)
-
+        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
+            log(f"chat/completions: model resolve {post.model} -> {model_name}")
+            response_streamer = litellm_streamer() if post.stream else litellm_non_streamer()
         else:
-            async def chat_completion_streamer(post: ChatContext):
-                post_url = "http://127.0.0.1:8001/v1/chat"
-                post_data = {
-                    "messages": [m.dict() for m in post.messages],
-                    "stream": True,
-                    "model": post.model,
-                    "parameters": {
-                        "temperature": post.temperature,
-                        "max_new_tokens": post.max_tokens,
-                    }
-                }
-                async with aiohttp.ClientSession() as session:
-                    try:
-                        async with session.post(post_url, json=post_data) as response:
-                            finish_reason = None
-                            async for data, _ in response.content.iter_chunks():
-                                try:
-                                    data = data.decode("utf-8")
-                                    data = json.loads(data[len(prefix):-len(postfix)])
-                                    finish_reason = data["choices"][0]["finish_reason"]
-                                    data["choices"][0]["finish_reason"] = None
-                                except json.JSONDecodeError:
-                                    data = {"choices": [{"finish_reason": finish_reason}]}
-                                yield prefix + json.dumps(data) + postfix
-                    except aiohttp.ClientConnectorError as e:
-                        err_msg = f"LSP server is not ready yet: {e}"
-                        log(err_msg)
-                        yield prefix + json.dumps({"error": err_msg}) + postfix
-
-            response_streamer = chat_completion_streamer(post)
+            response_streamer = chat_completion_streamer()
 
         return StreamingResponse(response_streamer, media_type="text/event-stream")
 
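
The two `# TODO: validate using pydantic` markers leave `tool_calls` and `tools` as untyped dicts. A possible follow-up, sketched here under the assumption that the fields follow the OpenAI-style tool-call shape (all class names below are hypothetical and not part of this diff):

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class ToolCallFunction(BaseModel):
    name: str
    arguments: str  # JSON-encoded arguments string, as emitted by OpenAI-compatible backends


class ToolCall(BaseModel):
    id: str
    type: str = "function"
    function: ToolCallFunction


class ToolFunction(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None  # JSON Schema describing the arguments


class Tool(BaseModel):
    type: str = "function"
    function: ToolFunction


class ChatMessage(BaseModel):
    role: str
    content: str
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
```

With these models, `tools: Optional[List[Tool]] = None` on `ChatContext` would reject malformed tool definitions at request time instead of passing them straight through to litellm.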
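For manual testing, a minimal streaming client for the updated handler. This assumes the webgui is reachable at `http://127.0.0.1:8008`, that `_chat_completions` is mounted at `/v1/chat/completions`, and that bearer-token auth is configured; none of that appears in this diff, so adjust the URL, model name, and key to your deployment:

```python
import asyncio
import json

import aiohttp

# Assumed base URL, route, and API key -- not taken from this diff.
URL = "http://127.0.0.1:8008/v1/chat/completions"
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}

PAYLOAD = {
    "model": "gpt-4o",  # must resolve to a litellm passthrough model on the server
    "stream": True,
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    "tool_choice": "auto",
}


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        async with session.post(URL, json=PAYLOAD, headers=HEADERS) as resp:
            # The handler streams SSE-style events: "data: {...}\n\n", terminated by "data: [DONE]\n\n".
            async for raw in resp.content:
                line = raw.decode("utf-8").strip()
                if not line or line == "data: [DONE]":
                    continue
                event = json.loads(line[len("data: "):])
                print(event.get("choices") or event)  # an {"error": ...} payload has no "choices"


asyncio.run(main())
```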