
Commit

init
no tools for stream

fix

fixes

support non stream for chat

works
valaises committed Jun 11, 2024
1 parent 3481ff8 commit 0332cc6
Showing 1 changed file with 88 additions and 56 deletions.
144 changes: 88 additions & 56 deletions refact_webgui/webgui/selfhost_fastapi_completions.py
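Context for the diff below: the change lets the chat endpoint pass tools / tool_choice through to litellm and adds a non-streaming code path ("support non stream for chat") selected by the existing stream flag. A minimal sketch of a request body exercising the new fields — the model name and the tool definition are illustrative placeholders, not values taken from this commit:

# Hypothetical request body matching the updated ChatContext schema below.
# Only the field names (model, messages, tools, tool_choice, stream) come from
# the diff; every value is a made-up example.
chat_request = {
    "model": "gpt-4o",
    "messages": [
        {"role": "user", "content": "What is the weather in Berlin?"},
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",              # invented tool for illustration
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    "tool_choice": "auto",
    "stream": False,   # False routes to the new litellm_non_streamer() path
}

With stream=False the handler yields a single JSON document; with stream=True it keeps the existing server-sent "data: ..." framing.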
@@ -72,11 +72,19 @@ class NlpCompletion(NlpSamplingParams):
class ChatMessage(BaseModel):
    role: str
    content: str
    # TODO: validate using pydantic
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None


class ChatContext(NlpSamplingParams):
    model: str = Query(pattern="^[a-z/A-Z0-9_\.\-]+$")
    messages: List[ChatMessage]
    # TODO: validate using pydantic
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[str] = None
    stream: bool = True
    stop: Optional[Any]
    n: int = 1
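
The two TODO-marked additions above are pass-through fields for OpenAI-style tool calling: tools / tool_choice on the request, tool_calls / tool_call_id on individual messages. A hedged sketch of the message shapes those per-message fields are meant to carry — the call id, tool name, and arguments are invented for illustration, not part of this commit:

# Illustrative ChatMessage payloads for a tool-call round trip; only the field
# names (role, content, tool_calls, tool_call_id) come from the model above.
assistant_turn = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": "call_0",                     # hypothetical call id
        "type": "function",
        "function": {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"},
    }],
}
tool_turn = {
    "role": "tool",
    "content": "{\"temperature_c\": 21}",   # tool output serialized as text
    "tool_call_id": "call_0",               # links the result to the call above
}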


@@ -479,70 +487,94 @@ async def _models(self, authorization: str = Header(None)):
        }

    async def _chat_completions(self, post: ChatContext, authorization: str = Header(None)):
        account = await self._account_from_bearer(authorization)

        _account = await self._account_from_bearer(authorization)
        messages = [m.dict() for m in post.messages]
        prefix, postfix = "data: ", "\n\n"
        model_dict = self._model_assigner.models_db_with_passthrough.get(post.model, {})

        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
            log(f"chat/completions: model resolve {post.model} -> {model_name}")

            async def litellm_streamer(post: ChatContext):
        async def litellm_streamer():
            try:
                self._integrations_env_setup()
                response = await litellm.acompletion(
                    model=model_name, messages=messages, stream=True,
                    temperature=post.temperature, top_p=post.top_p,
                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                    tools=post.tools,
                    tool_choice=post.tool_choice,
                    stop=post.stop
                )
                finish_reason = None
                async for model_response in response:
                    try:
                        data = model_response.dict()
                        finish_reason = data["choices"][0]["finish_reason"]
                    except json.JSONDecodeError:
                        data = {"choices": [{"finish_reason": finish_reason}]}
                    yield prefix + json.dumps(data) + postfix
                # NOTE: DONE needed by refact-lsp server
                yield prefix + "[DONE]" + postfix
            except BaseException as e:
                err_msg = f"litellm error: {e}"
                log(err_msg)
                yield prefix + json.dumps({"error": err_msg}) + postfix

        async def litellm_non_streamer():
            try:
                self._integrations_env_setup()
                model_response = await litellm.acompletion(
                    model=model_name, messages=messages, stream=False,
                    temperature=post.temperature, top_p=post.top_p,
                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                    tools=post.tools,
                    tool_choice=post.tool_choice,
                    stop=post.stop
                )
                finish_reason = None
                try:
                    data = model_response.dict()
                    finish_reason = data["choices"][0]["finish_reason"]
                except json.JSONDecodeError:
                    data = {"choices": [{"finish_reason": finish_reason}]}
                yield json.dumps(data)
            except BaseException as e:
                err_msg = f"litellm error: {e}"
                log(err_msg)
                yield json.dumps({"error": err_msg})

        async def chat_completion_streamer():
            post_url = "http://127.0.0.1:8001/v1/chat"
            payload = {
                "messages": messages,
                "stream": True,
                "model": post.model,
                "parameters": {
                    "temperature": post.temperature,
                    "max_new_tokens": post.max_tokens,
                }
            }
            async with aiohttp.ClientSession() as session:
                try:
                    self._integrations_env_setup()
                    response = await litellm.acompletion(
                        model=model_name, messages=[m.dict() for m in post.messages], stream=True,
                        temperature=post.temperature, top_p=post.top_p,
                        max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                        stop=post.stop)
                    finish_reason = None
                    async for model_response in response:
                        try:
                            data = model_response.dict()
                            finish_reason = data["choices"][0]["finish_reason"]
                        except json.JSONDecodeError:
                            data = {"choices": [{"finish_reason": finish_reason}]}
                        yield prefix + json.dumps(data) + postfix
                    # NOTE: DONE needed by refact-lsp server
                    yield prefix + "[DONE]" + postfix
                except BaseException as e:
                    err_msg = f"litellm error: {e}"
                    async with session.post(post_url, json=payload) as response:
                        finish_reason = None
                        async for data, _ in response.content.iter_chunks():
                            try:
                                data = data.decode("utf-8")
                                data = json.loads(data[len(prefix):-len(postfix)])
                                finish_reason = data["choices"][0]["finish_reason"]
                                data["choices"][0]["finish_reason"] = None
                            except json.JSONDecodeError:
                                data = {"choices": [{"finish_reason": finish_reason}]}
                            yield prefix + json.dumps(data) + postfix
                except aiohttp.ClientConnectorError as e:
                    err_msg = f"LSP server is not ready yet: {e}"
                    log(err_msg)
                    yield prefix + json.dumps({"error": err_msg}) + postfix

            response_streamer = litellm_streamer(post)

        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
            log(f"chat/completions: model resolve {post.model} -> {model_name}")
            response_streamer = litellm_streamer() if post.stream else litellm_non_streamer()
        else:
            async def chat_completion_streamer(post: ChatContext):
                post_url = "http://127.0.0.1:8001/v1/chat"
                post_data = {
                    "messages": [m.dict() for m in post.messages],
                    "stream": True,
                    "model": post.model,
                    "parameters": {
                        "temperature": post.temperature,
                        "max_new_tokens": post.max_tokens,
                    }
                }
                async with aiohttp.ClientSession() as session:
                    try:
                        async with session.post(post_url, json=post_data) as response:
                            finish_reason = None
                            async for data, _ in response.content.iter_chunks():
                                try:
                                    data = data.decode("utf-8")
                                    data = json.loads(data[len(prefix):-len(postfix)])
                                    finish_reason = data["choices"][0]["finish_reason"]
                                    data["choices"][0]["finish_reason"] = None
                                except json.JSONDecodeError:
                                    data = {"choices": [{"finish_reason": finish_reason}]}
                                yield prefix + json.dumps(data) + postfix
                    except aiohttp.ClientConnectorError as e:
                        err_msg = f"LSP server is not ready yet: {e}"
                        log(err_msg)
                        yield prefix + json.dumps({"error": err_msg}) + postfix

            response_streamer = chat_completion_streamer(post)
            response_streamer = chat_completion_streamer()

        return StreamingResponse(response_streamer, media_type="text/event-stream")

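As a usage note for the change above, here is a hedged client sketch of both code paths. Assumptions not taken from the diff: the webgui listens on http://127.0.0.1:8008, the handler is mounted at an OpenAI-style /v1/chat/completions route, and a bearer token is expected (the handler does call _account_from_bearer).

# Hedged client sketch -- base URL, route, and token are assumptions, not facts
# from this commit. Requires Python 3.9+ for str.removeprefix.
import json
import requests

BASE_URL = "http://127.0.0.1:8008"
HEADERS = {"Authorization": "Bearer <your-key>"}

body = {
    "model": "gpt-4o",   # illustrative model name
    "messages": [{"role": "user", "content": "hello"}],
    "stream": False,     # non-streaming: litellm_non_streamer() yields one JSON document
}
resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=body, headers=HEADERS)
print(resp.json())

body["stream"] = True    # streaming: "data: {...}\n\n" chunks, ending with "data: [DONE]"
with requests.post(f"{BASE_URL}/v1/chat/completions", json=body, headers=HEADERS, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue
        chunk = raw.decode("utf-8").removeprefix("data: ")
        if chunk == "[DONE]":          # sentinel kept for refact-lsp
            break
        print(json.loads(chunk))

The non-streaming branch emits the JSON body without the "data: " prefix, while the streaming branches keep the "data: " framing, and the litellm streamer additionally ends with the [DONE] sentinel that refact-lsp relies on.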
