Skip to content

Commit

Permalink
Feature/remove retrievers (#1314)
Browse files Browse the repository at this point in the history
* rationalize metadata

* remove unused file

* bug fix

* bug fix

* removed retiever

* s3_keys -> documents

* bug fix

* fix bug

* formatting
  • Loading branch information
gecBurton authored Jan 23, 2025
1 parent 40a2042 commit 9f3811c
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 194 deletions.
Binary file removed django_app/files/RiskTriggersReport361.pdf
Binary file not shown.
23 changes: 9 additions & 14 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django.conf import settings
from django.contrib.auth import get_user_model
from django.forms.models import model_to_dict
from langchain_core.documents import Document
from openai import RateLimitError
from websockets import ConnectionClosedError, WebSocketClientProtocol

Expand All @@ -21,7 +22,6 @@
RedboxQuery,
RedboxState,
)
from redbox.retriever import DjangoFileRetriever
from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.models import AISettings as AISettingsModel
from redbox_app.redbox_core.models import (
Expand Down Expand Up @@ -56,15 +56,11 @@ def escape_curly_brackets(text: str):
class ChatConsumer(AsyncWebsocketConsumer):
full_reply: ClassVar = []
route = None
redbox = Redbox(
retriever=DjangoFileRetriever(file_manager=File.objects),
debug=True,
)
redbox = Redbox(debug=settings.DEBUG)

async def receive(self, text_data=None, bytes_data=None):
"""Receive & respond to message from browser websocket."""
self.full_reply = []
self.external_citations = []
self.route = None

data = json.loads(text_data or bytes_data)
Expand Down Expand Up @@ -104,17 +100,17 @@ async def receive(self, text_data=None, bytes_data=None):
return

# save user message
permitted_files = File.objects.filter(user=user, status=File.Status.complete)
selected_files = permitted_files.filter(id__in=selected_file_uuids) | session.file_set.all()
selected_files = (
File.objects.filter(user=user, status=File.Status.complete, id__in=selected_file_uuids)
| session.file_set.all()
)

await self.save_user_message(session, user_message_text, selected_files=selected_files)

await self.llm_conversation(selected_files, session, user, user_message_text, permitted_files)
await self.llm_conversation(selected_files, session, user, user_message_text)
await self.close()

async def llm_conversation(
self, selected_files: Sequence[File], session: Chat, user: User, title: str, permitted_files: Sequence[File]
) -> None:
async def llm_conversation(self, selected_files: Sequence[File], session: Chat, user: User, title: str) -> None:
"""Initiate & close websocket conversation with the core-api message endpoint."""
await self.send_to_client("session-id", session.id)

Expand All @@ -136,7 +132,7 @@ async def llm_conversation(
state = RedboxState(
request=RedboxQuery(
question=message_history[-1].text,
s3_keys=[f.unique_name for f in selected_files],
documents=[Document(str(f.text), metadata={"uri": f.original_file.name}) for f in selected_files],
user_uuid=user.id,
chat_history=[
ChainChatMessage(
Expand All @@ -146,7 +142,6 @@ async def llm_conversation(
for message in message_history[:-1]
],
ai_settings=ai_settings,
permitted_s3_keys=[f.unique_name async for f in permitted_files],
),
)

Expand Down
9 changes: 3 additions & 6 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,9 @@ async def test_chat_consumer_redbox_state(
assert connected

selected_file_uuids: Sequence[str] = [str(f.id) for f in selected_files]
selected_file_keys: Sequence[str] = [f.unique_name for f in selected_files]
permitted_file_keys: Sequence[str] = [
f.unique_name async for f in File.objects.filter(user=alice, status=File.Status.complete)
documents: Sequence[str] = [
Document(page_content=str(f.text), metadata={"uri": f.original_file.name}) for f in selected_files
]
assert selected_file_keys != permitted_file_keys

await communicator.send_json_to(
{
Expand All @@ -413,7 +411,7 @@ async def test_chat_consumer_redbox_state(
# Then
expected_request = RedboxQuery(
question="Third question, with selected files?",
s3_keys=selected_file_keys,
documents=documents,
user_uuid=alice.id,
chat_history=[
{"role": "user", "text": "A question?"},
Expand All @@ -422,7 +420,6 @@ async def test_chat_consumer_redbox_state(
{"role": "ai", "text": "A second answer."},
],
ai_settings=ai_settings,
permitted_s3_keys=permitted_file_keys,
)
redbox_state = mock_run.call_args.args[0] # pulls out the args that redbox.run was called with

Expand Down
4 changes: 1 addition & 3 deletions redbox-core/redbox/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ async def _default_callback(*args, **kwargs):


class Redbox:
def __init__(self, retriever, debug: bool = False):
self.retriever = retriever
def __init__(self, debug: bool = False):
self.debug = debug

def _get_runnable(self, state: RedboxState):
return build_llm_chain(
prompt_set=PromptSet.ChatwithDocs,
retriever=self.retriever,
llm=get_chat_llm(state.request.ai_settings.chat_backend),
)

Expand Down
104 changes: 2 additions & 102 deletions redbox-core/redbox/chains/runnables.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
import logging
import re
from typing import Any, Iterator

from langchain_core.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import (
Runnable,
RunnableLambda,
chain,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tiktoken import Encoding

from redbox.chains.components import get_tokeniser
from redbox.models.chain import ChainChatMessage, PromptSet, RedboxState, get_prompts
from redbox.models.errors import QuestionLengthError
from redbox.api.format import format_documents
from redbox.retriever.retrievers import retriever_runnable

log = logging.getLogger()
re_string_pattern = re.compile(r"(\S+)")
Expand Down Expand Up @@ -55,9 +46,6 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable:
prompts_budget = len(_tokeniser.encode(task_system_prompt)) + len(_tokeniser.encode(task_question_prompt))
chat_history_budget = ai_settings.context_window_size - ai_settings.llm_max_tokens - prompts_budget

if chat_history_budget <= 0:
raise QuestionLengthError

truncated_history: list[ChainChatMessage] = []
for msg in state.request.chat_history[::-1]:
chat_history_budget -= len(_tokeniser.encode(msg["text"]))
Expand All @@ -70,7 +58,7 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable:
state.request.model_dump()
| {
"messages": state.messages,
"formatted_documents": format_documents(state.documents),
"formatted_documents": format_documents(state.request.documents),
}
| _additional_variables
)
Expand All @@ -90,98 +78,10 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable:

def build_llm_chain(
prompt_set: PromptSet,
retriever: Runnable,
llm: BaseChatModel,
) -> Runnable:
"""Builds a chain that correctly forms a text and metadata state update.
Permits both invoke and astream_events.
"""
return RunnableLambda(retriever_runnable(retriever)) | build_chat_prompt_from_messages_runnable(prompt_set) | llm


class CannedChatLLM(BaseChatModel):
"""A custom chat model that returns its text as if an LLM returned it.
Based on https://python.langchain.com/v0.2/docs/how_to/custom_chat_model/
"""

messages: list[AIMessage]

def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Run the LLM on the given input.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
message = AIMessage(content=self.messages[-1].content)

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

def _stream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the LLM on the given prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
for token in re_string_pattern.split(self.messages[-1].content):
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))

if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)

yield chunk

# Final token should be empty
chunk = ChatGenerationChunk(message=AIMessageChunk(content=""))
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)

yield chunk

@property
def _identifying_params(self) -> dict[str, Any]:
"""Return a dictionary of identifying parameters."""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": "CannedChatModel",
}

@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "canned"
return build_chat_prompt_from_messages_runnable(prompt_set) | llm
10 changes: 1 addition & 9 deletions redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,16 @@ class AISettings(BaseModel):

class RedboxQuery(BaseModel):
question: str = Field(description="The last user chat message")
s3_keys: list[str] = Field(description="List of files to process", default_factory=list)
documents: list[Document] = Field(description="List of files to process", default_factory=list)
user_uuid: UUID = Field(description="User the chain in executing for")
chat_history: list[ChainChatMessage] = Field(description="All previous messages in chat (excluding question)")
ai_settings: AISettings = Field(description="User request AI settings", default_factory=AISettings)
permitted_s3_keys: list[str] = Field(description="List of permitted files for response", default_factory=list)


class RedboxState(BaseModel):
request: RedboxQuery
documents: list[Document] = Field(default_factory=list)
messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)

@property
def last_message(self) -> AnyMessage:
if not self.messages:
raise ValueError("No messages in the state")
return self.messages[-1]


class PromptSet(StrEnum):
Chat = "chat"
Expand Down
10 changes: 0 additions & 10 deletions redbox-core/redbox/models/errors.py

This file was deleted.

3 changes: 0 additions & 3 deletions redbox-core/redbox/retriever/__init__.py

This file was deleted.

46 changes: 0 additions & 46 deletions redbox-core/redbox/retriever/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,46 +0,0 @@
from logging import getLogger

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from redbox.models.chain import RedboxState
from typing import Any

logger = getLogger(__name__)


class DjangoFileRetriever(BaseRetriever):
file_manager: Any = None

def _get_relevant_documents(
self, query: RedboxState, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
selected_files = set(query.request.s3_keys)
permitted_files = set(query.request.permitted_s3_keys)

if not selected_files <= permitted_files:
logger.warning(
"User has selected files they aren't permitted to access: \n"
f"{", ".join(selected_files - permitted_files)}"
)

file_names = list(selected_files & permitted_files)

files = self.file_manager.filter(original_file__in=file_names, text__isnull=False)

return [
Document(
page_content=file.text,
metadata={"uri": file.original_file.name},
)
for file in files
]


def retriever_runnable(retriever: BaseRetriever):
def _run(state: RedboxState):
state.documents = retriever.invoke(state)
return state

return _run
Empty file.
2 changes: 1 addition & 1 deletion streamlit_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run_streamlit():
state = RedboxState(
request=RedboxQuery(
question=prompt,
s3_keys=[],
documents=[],
user_uuid=uuid4(),
chat_history=[],
ai_settings=st.session_state.ai_settings,
Expand Down

0 comments on commit 9f3811c

Please sign in to comment.