Feature: adding llm prompts to the UI #232

Merged (7 commits) on Nov 22, 2023. Changes from all commits; two files changed.
First changed file, the request handler (file path not shown in this capture). The handler now records the prompt template in the run metadata and attaches that metadata to the response sent to the client:

```diff
@@ -119,6 +119,7 @@ def handle_run(record):
         "mode": mode,
         "sessionId": session_id,
         "userId": user_id,
+        "prompts": [prompt_template],
     }
     if files:
         metadata["files"] = files
@@ -131,6 +132,7 @@ def handle_run(record):
         "sessionId": session_id,
         "type": "text",
         "content": mlm_response,
+        "metadata": metadata,
     }
 
     send_to_client(
```
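With these two additions, the captured prompts travel inside the metadata attached to the response that `send_to_client` pushes back to the UI. Roughly, the payload now has this shape; all field values below are made-up for illustration, and the exact envelope `send_to_client` wraps around it is not shown in this diff:

```python
# Hypothetical values; only the structure mirrors the handler code above.
response = {
    "sessionId": "session-123",
    "type": "text",
    "content": "...model answer...",
    "metadata": {
        "mode": "chain",
        "sessionId": "session-123",
        "userId": "user-456",
        "prompts": ["...the rendered prompt template..."],
        # "files": [...] is added only when files were attached
    },
}
```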
Second changed file, the model adapter (file path not shown in this capture). First, the typing imports used by the new callback handler:

```diff
@@ -9,6 +9,7 @@
     QA_PROMPT,
     CONDENSE_QUESTION_PROMPT,
 )
+from typing import Dict, List, Any
 
 from genai_core.langchain import WorkspaceRetriever, DynamoDBChatMessageHistory
 from genai_core.types import ChatbotMode
```
```diff
@@ -20,6 +21,16 @@ class Mode(Enum):
     CHAIN = "chain"
 
 
+class LLMStartHandler(BaseCallbackHandler):
+    prompts = []
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> Any:
+        logger.info(prompts)
+        self.prompts.append(prompts)
+
+
 class ModelAdapter:
     def __init__(
         self, session_id, user_id, mode=ChatbotMode.CHAIN.value, model_kwargs={}
```
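For context, a minimal self-contained sketch of what the new handler does: LangChain invokes `on_llm_start` with the fully rendered prompt strings just before each model call, so the handler records exactly the text the LLM received. The serialized dict and prompt text below are made-up values, and the import path is spelled as in 2023-era LangChain:

```python
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler


class LLMStartHandler(BaseCallbackHandler):
    # Class-level list; see the per-run reset in run_with_chain below.
    prompts = []

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        # LangChain passes the final prompt strings, with all template
        # variables already substituted.
        self.prompts.append(prompts)


handler = LLMStartHandler()
# Simulate what LangChain does before a model call (made-up values):
handler.on_llm_start({"name": "llm"}, ["Question: What is RAG?"])
print(handler.prompts)  # [['Question: What is RAG?']]
```

Because a single chain run may invoke the LLM more than once (for example, question condensing plus answering in `ConversationalRetrievalChain`), `prompts` ends up as a list of lists.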
```diff
@@ -29,7 +40,7 @@ def __init__(
         self._mode = mode
         self.model_kwargs = model_kwargs
 
-        self.callback_handler = BaseCallbackHandler()
+        self.callback_handler = LLMStartHandler()
         self.__bind_callbacks()
 
         self.chat_history = self.get_chat_history()
```
A whitespace-only change in `get_prompt` (the blank line before the return lost its trailing spaces):

```diff
@@ -73,7 +84,7 @@ def get_prompt(self):
 {chat_history}
 
 Question: {input}"""
-        
+
         return PromptTemplate.from_template(template)
 
     def get_condense_question_prompt(self):
```
```diff
@@ -86,6 +97,8 @@ def run_with_chain(self, user_prompt, workspace_id=None):
         if not self.llm:
             raise ValueError("llm must be set")
 
+        self.callback_handler.prompts = []
+
         if workspace_id:
             conversation = ConversationalRetrievalChain.from_llm(
                 self.llm,
```
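One detail worth noting: `prompts = []` on the handler is a class attribute, so appending to it mutates a list shared by every instance. Assigning `self.callback_handler.prompts = []` at the start of each run creates an instance-level list, which keeps prompts from leaking across runs. A short illustration of that Python behavior, using a hypothetical `Handler` class:

```python
class Handler:
    prompts = []  # class attribute: one list shared by all instances


a, b = Handler(), Handler()
a.prompts.append("x")        # mutates the shared class-level list
print(b.prompts)             # ['x']  -- visible from another instance
a.prompts = []               # assignment creates an instance attribute
a.prompts.append("y")
print(Handler.prompts)       # ['x']  -- class list no longer affected
print(a.prompts)             # ['y']  -- per-instance list, isolated
```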
```diff
@@ -116,6 +129,7 @@ def run_with_chain(self, user_prompt, workspace_id=None):
             "userId": self.user_id,
             "workspaceId": workspace_id,
             "documents": documents,
+            "prompts": self.callback_handler.prompts,
         }
 
         self.chat_history.add_metadata(metadata)
```
```diff
@@ -144,6 +158,7 @@ def run_with_chain(self, user_prompt, workspace_id=None):
             "sessionId": self.session_id,
             "userId": self.user_id,
             "documents": [],
+            "prompts": self.callback_handler.prompts,
         }
 
         self.chat_history.add_metadata(metadata)
```