Skip to content

Commit

Permalink
update UI
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jan 4, 2024
1 parent 178321f commit e927f74
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
19 changes: 17 additions & 2 deletions ragna/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import builtins
import sys
from typing import Callable, Iterable, Iterator, Mapping, TypeVar
from typing import AsyncIterator, Callable, Iterable, Iterator, Mapping, TypeVar

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"]
__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions", "anext"]

T = TypeVar("T")

Expand Down Expand Up @@ -38,3 +39,17 @@ def _importlib_metadata_package_distributions() -> (


importlib_metadata_package_distributions = _importlib_metadata_package_distributions()


def _anext() -> Callable[[AsyncIterator[T]], T]:
if sys.version_info[:2] >= (3, 10):
anext = builtins.anext
else:

async def anext(ait: AsyncIterator[T]) -> T:
return await ait.__anext__()

return anext


anext = _anext()
21 changes: 10 additions & 11 deletions ragna/deploy/_ui/api_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from datetime import datetime

import emoji
import httpx
import httpx_sse
import param


Expand Down Expand Up @@ -62,17 +64,14 @@ async def get_chats(self):
return json_data

async def answer(self, chat_id, prompt):
return self.improve_message(
(
await self.client.post(
f"/chats/{chat_id}/answer",
params={"prompt": prompt},
timeout=None,
)
)
.raise_for_status()
.json()
)
async with httpx_sse.aconnect_sse(
self.client,
"POST",
f"/chats/{chat_id}/answer",
json={"prompt": prompt, "stream": True},
) as event_source:
async for sse in event_source.aiter_sse():
yield self.improve_message(json.loads(sse.data))

async def get_components(self):
return (await self.client.get("/components")).raise_for_status().json()
Expand Down
12 changes: 10 additions & 2 deletions ragna/deploy/_ui/central_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import param
from panel.reactive import ReactiveHTML

from ragna._compat import anext

from . import styles as ui

# TODO : move all the CSS rules in a dedicated file
Expand Down Expand Up @@ -370,15 +372,21 @@ async def chat_callback(
self, content: str, user: str, instance: pn.chat.ChatInterface
):
try:
answer = await self.api_wrapper.answer(self.current_chat["id"], content)
answer_stream = self.api_wrapper.answer(self.current_chat["id"], content)
answer = await anext(answer_stream)

yield RagnaChatMessage(
message = RagnaChatMessage(
answer["content"],
role="assistant",
user=self.get_user_from_role("assistant"),
sources=answer["sources"],
on_click_source_info_callback=self.on_click_source_info_wrapper,
)
yield message

async for chunk in answer_stream:
message.object += chunk["content"]

except Exception:
yield RagnaChatMessage(
(
Expand Down

0 comments on commit e927f74

Please sign in to comment.