-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
177 lines (149 loc) · 5.57 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import streamlit as st
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_cohere import ChatCohere
from langchain_community.chat_message_histories.streamlit import (
StreamlitChatMessageHistory,
)
from langchain_openai import ChatOpenAI
from calback_handler import PrintRetrievalHandler, StreamHandler
from chat_profile import ChatProfileRoleEnum
from document_retriever import configure_retriever
from llm_provider import LLMProviderEnum
# Constants: model identifiers sent to each provider's API.
GPT_LLM_MODEL = "gpt-3.5-turbo"  # OpenAI chat model name
COMMAND_R_LLM_MODEL = "command-r"  # Cohere chat model name
# Properties: module-level state filled in as the user interacts with the UI.
uploaded_files = []  # files chosen in the sidebar uploader
api_key = ""  # provider API key typed by the user
result_retriever = None  # document retriever built from the uploads
chain = None  # ConversationalRetrievalChain once fully configured
llm = None  # chat model instance (ChatOpenAI or ChatCohere)
model_name = ""  # resolved model identifier for the selected provider
# Remember sidebar visibility across Streamlit reruns; default to expanded.
if "sidebar_state" not in st.session_state:
    st.session_state["sidebar_state"] = "expanded"

# Streamlit page chrome: title, icon, sidebar default, and help-menu links.
_ABOUT_TEXT = """InkChatGPT is a simple Retrieval Augmented Generation (RAG) application that allows users to upload PDF documents and engage in a conversational Q&A, with a language model (LLM) based on the content of those documents.
GitHub: https://github.com/vinhnx/InkChatGPT"""
_MENU_ITEMS = {
    "Get Help": "https://x.com/vinhnx",
    "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
    "About": _ABOUT_TEXT,
}
st.set_page_config(
    page_title="InkChatGPT: Chat with Documents",
    page_icon="📚",
    initial_sidebar_state=st.session_state["sidebar_state"],
    menu_items=_MENU_ITEMS,
)
with st.sidebar:
    # App branding: icon and title side by side.
    with st.container():
        col1, col2 = st.columns([0.2, 0.8])
        with col1:
            st.image(
                "./assets/app_icon.png",
                use_column_width="always",
                output_format="PNG",
            )
        with col2:
            st.header(":books: InkChatGPT")

    # Model selection. The options are the enum *values* (plain strings),
    # so st.selectbox returns a string, not an enum member.
    selected_model = st.selectbox(
        "Select a model",
        options=[
            LLMProviderEnum.OPEN_AI.value,
            LLMProviderEnum.COHERE.value,
        ],
        index=None,
        placeholder="Select a model...",
    )
    if selected_model:
        api_key = st.text_input(f"{selected_model} API Key", type="password")
        # BUG FIX: selected_model is the enum .value string; the original
        # compared it against the enum member itself, which is always False
        # for a non-str Enum, so model_name was never assigned. Comparing
        # against .value is correct for both plain and str-based enums.
        if selected_model == LLMProviderEnum.OPEN_AI.value:
            model_name = GPT_LLM_MODEL
        elif selected_model == LLMProviderEnum.COHERE.value:
            model_name = COMMAND_R_LLM_MODEL

    # Chat history backed by Streamlit session state; seed a greeting when
    # the history is empty or the user clears the conversation.
    msgs = StreamlitChatMessageHistory()
    if len(msgs.messages) == 0 or st.button("Clear message history"):
        msgs.clear()
        msgs.add_ai_message("""
Hi, your uploaded document(s) had been analyzed.
Feel free to ask me any questions. For example: you can start by asking me something like:
`What is this context about?`
`Help me summarize this!`
""")

    if api_key:
        # Documents: uploader stays disabled until a model is chosen.
        uploaded_files = st.file_uploader(
            label="Select files",
            type=["pdf", "txt", "docx"],
            accept_multiple_files=True,
            disabled=(not selected_model),
        )
# Onboarding hints shown in the main panel, in order of setup progress.
needs_documents = bool(api_key) and not uploaded_files
if needs_documents:
    st.info("🌟 You can upload some documents to get started")

# Check if a model is selected
if not selected_model:
    st.info(
        "📺 Please select a model first, open the `Settings` tab from side bar menu to get started"
    )

# Check if API key is provided
key_is_blank = selected_model and not api_key.strip()
if key_is_blank:
    st.warning(
        f"🔑 API key for {selected_model} is missing or invalid. Please provide a valid API key."
    )
# Process uploaded files: build the retriever, the LLM, and the QA chain.
if uploaded_files:
    # NOTE(review): api_key is forwarded as cohere_api_key even when OpenAI
    # is selected — verify configure_retriever handles both providers.
    result_retriever = configure_retriever(uploaded_files, cohere_api_key=api_key)

    if result_retriever is not None:
        # Conversation memory persisted via the Streamlit message history.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            chat_memory=msgs,
            return_messages=True,
        )

        # BUG FIX: selected_model is the enum .value string (that is what the
        # selectbox returns), so compare with .value — the original compared
        # against the enum member, which never matches a plain Enum and left
        # llm as None.
        if selected_model == LLMProviderEnum.OPEN_AI.value:
            llm = ChatOpenAI(
                model=model_name,
                api_key=api_key,
                temperature=0,
                streaming=True,
            )
        elif selected_model == LLMProviderEnum.COHERE.value:
            llm = ChatCohere(
                model=model_name,
                temperature=0.3,
                streaming=True,
                cohere_api_key=api_key,
            )

        if llm is None:
            st.error(
                "Failed to initialize the language model. Please check your configuration."
            )
            # BUG FIX: halt this script run — the original fell through and
            # called from_llm(llm=None), which would raise.
            st.stop()

        # Create the ConversationalRetrievalChain instance using the llm instance
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=result_retriever,
            memory=memory,
            verbose=True,
            max_tokens_limit=4000,
        )
# Map LangChain message types to Streamlit chat roles and replay history.
avatars = {
    ChatProfileRoleEnum.HUMAN.value: "user",
    ChatProfileRoleEnum.AI.value: "assistant",
}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

# Get user input and generate response
if user_query := st.chat_input(
    placeholder="Ask me anything!",
    disabled=(not uploaded_files),
):
    st.chat_message("user").write(user_query)
    with st.chat_message("assistant"):
        if chain is None:
            # BUG FIX: guard against an unconfigured chain (retriever or
            # model setup failed) instead of raising AttributeError on None.
            st.error("The QA chain is not ready. Please check your model selection and documents.")
        else:
            retrieval_handler = PrintRetrievalHandler(st.empty())
            stream_handler = StreamHandler(st.empty())
            # The StreamHandler streams tokens into the placeholder, so the
            # chain's return value is not needed (unused `response` removed).
            chain.run(
                user_query,
                callbacks=[retrieval_handler, stream_handler],
            )

if selected_model and model_name:
    st.sidebar.caption(f"🪄 Using `{model_name}` model")