Tooluse #430

Merged · 1 commit · Jun 11, 2024

Changes from all commits
refact_webgui/webgui/selfhost_fastapi_completions.py (144 changes: 88 additions & 56 deletions)

@@ -72,11 +72,19 @@ class NlpCompletion(NlpSamplingParams):
 class ChatMessage(BaseModel):
     role: str
     content: str
+    # TODO: validate using pydantic
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
 
 
 class ChatContext(NlpSamplingParams):
     model: str = Query(pattern="^[a-z/A-Z0-9_\.\-]+$")
     messages: List[ChatMessage]
+    # TODO: validate using pydantic
+    tools: Optional[List[Dict[str, Any]]] = None
+    tool_choice: Optional[str] = None
+    stream: bool = True
     stop: Optional[Any]
     n: int = 1

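Note, not part of the diff: the new fields are accepted as plain Dict[str, Any] with validation left as a TODO, so the server does not enforce any particular shape yet. They appear to follow the OpenAI-style tool-calling format; below is a rough sketch of a request once tools are involved, where the tool name, call id and model name are made up for illustration.

# Illustrative sketch only; the exact dict shapes are an assumption based on the
# OpenAI chat format, since ChatMessage/ChatContext do not validate them yet.
assistant_turn = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": "call_0",                      # hypothetical call id
        "type": "function",
        "function": {"name": "definition", "arguments": "{\"symbol\": \"ChatContext\"}"},
    }],
}
tool_turn = {
    "role": "tool",
    "content": "class ChatContext(NlpSamplingParams): ...",
    "tool_call_id": "call_0",                # ties the result back to the call above
}
request_body = {
    "model": "gpt-4o",                       # placeholder model name
    "messages": [
        {"role": "user", "content": "Where is ChatContext defined?"},
        assistant_turn,
        tool_turn,
    ],
    "tools": [{"type": "function", "function": {"name": "definition"}}],
    "tool_choice": "auto",
    "stream": True,
}
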
@@ -479,70 +487,94 @@ async def _models(self, authorization: str = Header(None)):
         }
 
     async def _chat_completions(self, post: ChatContext, authorization: str = Header(None)):
-        account = await self._account_from_bearer(authorization)
-
+        _account = await self._account_from_bearer(authorization)
+        messages = [m.dict() for m in post.messages]
         prefix, postfix = "data: ", "\n\n"
         model_dict = self._model_assigner.models_db_with_passthrough.get(post.model, {})
 
-        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
-            log(f"chat/completions: model resolve {post.model} -> {model_name}")
-
-            async def litellm_streamer(post: ChatContext):
+        async def litellm_streamer():
+            try:
+                self._integrations_env_setup()
+                response = await litellm.acompletion(
+                    model=model_name, messages=messages, stream=True,
+                    temperature=post.temperature, top_p=post.top_p,
+                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
+                    tools=post.tools,
+                    tool_choice=post.tool_choice,
+                    stop=post.stop
+                )
+                finish_reason = None
+                async for model_response in response:
+                    try:
+                        data = model_response.dict()
+                        finish_reason = data["choices"][0]["finish_reason"]
+                    except json.JSONDecodeError:
+                        data = {"choices": [{"finish_reason": finish_reason}]}
+                    yield prefix + json.dumps(data) + postfix
+                # NOTE: DONE needed by refact-lsp server
+                yield prefix + "[DONE]" + postfix
+            except BaseException as e:
+                err_msg = f"litellm error: {e}"
+                log(err_msg)
+                yield prefix + json.dumps({"error": err_msg}) + postfix
+
+        async def litellm_non_streamer():
+            try:
+                self._integrations_env_setup()
+                model_response = await litellm.acompletion(
+                    model=model_name, messages=messages, stream=False,
+                    temperature=post.temperature, top_p=post.top_p,
+                    max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
+                    tools=post.tools,
+                    tool_choice=post.tool_choice,
+                    stop=post.stop
+                )
+                finish_reason = None
+                try:
+                    data = model_response.dict()
+                    finish_reason = data["choices"][0]["finish_reason"]
+                except json.JSONDecodeError:
+                    data = {"choices": [{"finish_reason": finish_reason}]}
+                yield json.dumps(data)
+            except BaseException as e:
+                err_msg = f"litellm error: {e}"
+                log(err_msg)
+                yield json.dumps({"error": err_msg})
+
+        async def chat_completion_streamer():
+            post_url = "http://127.0.0.1:8001/v1/chat"
+            payload = {
+                "messages": messages,
+                "stream": True,
+                "model": post.model,
+                "parameters": {
+                    "temperature": post.temperature,
+                    "max_new_tokens": post.max_tokens,
+                }
+            }
+            async with aiohttp.ClientSession() as session:
                 try:
-                    self._integrations_env_setup()
-                    response = await litellm.acompletion(
-                        model=model_name, messages=[m.dict() for m in post.messages], stream=True,
-                        temperature=post.temperature, top_p=post.top_p,
-                        max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
-                        stop=post.stop)
-                    finish_reason = None
-                    async for model_response in response:
-                        try:
-                            data = model_response.dict()
-                            finish_reason = data["choices"][0]["finish_reason"]
-                        except json.JSONDecodeError:
-                            data = {"choices": [{"finish_reason": finish_reason}]}
-                        yield prefix + json.dumps(data) + postfix
-                    # NOTE: DONE needed by refact-lsp server
-                    yield prefix + "[DONE]" + postfix
-                except BaseException as e:
-                    err_msg = f"litellm error: {e}"
+                    async with session.post(post_url, json=payload) as response:
+                        finish_reason = None
+                        async for data, _ in response.content.iter_chunks():
+                            try:
+                                data = data.decode("utf-8")
+                                data = json.loads(data[len(prefix):-len(postfix)])
+                                finish_reason = data["choices"][0]["finish_reason"]
+                                data["choices"][0]["finish_reason"] = None
+                            except json.JSONDecodeError:
+                                data = {"choices": [{"finish_reason": finish_reason}]}
+                            yield prefix + json.dumps(data) + postfix
+                except aiohttp.ClientConnectorError as e:
+                    err_msg = f"LSP server is not ready yet: {e}"
                     log(err_msg)
                     yield prefix + json.dumps({"error": err_msg}) + postfix
 
-            response_streamer = litellm_streamer(post)
-
+        if model_dict.get('backend') == 'litellm' and (model_name := model_dict.get('resolve_as', post.model)) in litellm.model_list:
+            log(f"chat/completions: model resolve {post.model} -> {model_name}")
+            response_streamer = litellm_streamer() if post.stream else litellm_non_streamer()
         else:
-            async def chat_completion_streamer(post: ChatContext):
-                post_url = "http://127.0.0.1:8001/v1/chat"
-                post_data = {
-                    "messages": [m.dict() for m in post.messages],
-                    "stream": True,
-                    "model": post.model,
-                    "parameters": {
-                        "temperature": post.temperature,
-                        "max_new_tokens": post.max_tokens,
-                    }
-                }
-                async with aiohttp.ClientSession() as session:
-                    try:
-                        async with session.post(post_url, json=post_data) as response:
-                            finish_reason = None
-                            async for data, _ in response.content.iter_chunks():
-                                try:
-                                    data = data.decode("utf-8")
-                                    data = json.loads(data[len(prefix):-len(postfix)])
-                                    finish_reason = data["choices"][0]["finish_reason"]
-                                    data["choices"][0]["finish_reason"] = None
-                                except json.JSONDecodeError:
-                                    data = {"choices": [{"finish_reason": finish_reason}]}
-                                yield prefix + json.dumps(data) + postfix
-                    except aiohttp.ClientConnectorError as e:
-                        err_msg = f"LSP server is not ready yet: {e}"
-                        log(err_msg)
-                        yield prefix + json.dumps({"error": err_msg}) + postfix
-
-            response_streamer = chat_completion_streamer(post)
+            response_streamer = chat_completion_streamer()
 
         return StreamingResponse(response_streamer, media_type="text/event-stream")
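
A note for readers trying the endpoint, not part of the diff: the streaming paths emit server-sent-event lines of the form "data: {...}" followed by a blank line, the litellm streaming path additionally ends with "data: [DONE]" (which refact-lsp relies on, per the NOTE in the code), and litellm_non_streamer yields a single JSON body without that framing. Below is a minimal streaming client sketch; it assumes the handler is served at /v1/chat/completions on a locally running webgui, and the host, port, bearer token and model name are placeholders.

import asyncio
import json

import aiohttp


async def main():
    url = "http://127.0.0.1:8008/v1/chat/completions"    # assumed host/port and route
    headers = {"Authorization": "Bearer YOUR_KEY"}        # placeholder token
    payload = {
        "model": "gpt-3.5-turbo",                         # any model the server can resolve
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=headers) as resp:
            # Naive SSE handling: assumes one "data: ..." event per chunk;
            # a robust client would buffer and split on blank lines.
            async for chunk, _ in resp.content.iter_chunks():
                line = chunk.decode("utf-8").strip()
                if not line.startswith("data: "):
                    continue
                body = line[len("data: "):]
                if body == "[DONE]":                      # sent by the litellm streaming path
                    break
                print(json.loads(body))


asyncio.run(main())

The non-litellm path proxies to the local refact-lsp server at 127.0.0.1:8001/v1/chat and re-wraps its chunks in the same "data: ..." framing, so the same client loop applies there.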
