# app.py — Streamlit chat app: retrieval-augmented QA over an uploaded
# txt/pdf document, powered by GigaChat + LangChain + Chroma.
import streamlit as st
import os
from dotenv import load_dotenv
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.gigachat import GigaChatEmbeddings
from langchain.chains import RetrievalQA
from chromadb.config import Settings
# --- Session-state bootstrap -------------------------------------------
# ``messages`` holds the chat transcript across reruns; ``file`` records
# whether a source document has been uploaded yet.
for _key, _default in (("messages", []), ("file", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
def set_file():
    """Uploader ``on_change`` callback: flag a new file and reset the chat."""
    state = st.session_state
    state.file = True
    state.messages = []
def rerun():
    """Delete the temporary source file and restart the script run.

    Relies on the module-level ``fextension`` set after a file upload.
    The removal is best-effort: on a second invocation (or if the file
    was never written) ``os.remove`` would raise ``FileNotFoundError``,
    so that case is tolerated instead of crashing the app.
    """
    try:
        os.remove(f'src.{fextension}')
    except FileNotFoundError:
        # Nothing to clean up — the temp file is already gone.
        pass
    st.rerun()
# GigaChat API credentials, read from Streamlit secrets (never hard-coded).
KEY=st.secrets['GIGA_KEY']
@st.cache_resource
def load_pipeline(uploaded_file):
    """Build a RetrievalQA chain over the uploaded document.

    Persists the upload to ``src.<ext>``, splits it into overlapping
    chunks (500 chars, 150 overlap), embeds them with GigaChat
    embeddings into an in-memory Chroma store, and returns a
    ``RetrievalQA`` chain backed by that store.

    Args:
        uploaded_file: a Streamlit ``UploadedFile`` (txt or pdf), or None.

    Returns:
        The RetrievalQA chain, or ``None`` when no file was uploaded.

    Raises:
        ValueError: if the file extension is neither ``txt`` nor ``pdf``.
    """
    if uploaded_file is None:
        return None
    # Derive the extension from the upload itself instead of reading the
    # module-level ``fextension`` global: a cached function must depend
    # only on its arguments, or the cache key is wrong (and calling it
    # before the global existed raised NameError).
    ext = os.path.splitext(uploaded_file.name)[1].lstrip('.').lower()
    path = f'src.{ext}'
    with open(path, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    st.session_state.file = True
    with st.spinner('Splitting and getting embeddings...'):
        if ext == 'txt':
            loader = TextLoader(path)
        elif ext == 'pdf':
            loader = PyPDFLoader(path)
        else:
            # Previously ``loader`` was simply left unbound here, which
            # surfaced later as an opaque NameError; fail loudly instead.
            raise ValueError(f'Unsupported file type: {ext!r}')
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=150,
        )
        documents = text_splitter.split_documents(documents)
        llm = GigaChat(credentials=KEY, verify_ssl_certs=False)
        embeddings = GigaChatEmbeddings(credentials=KEY, verify_ssl_certs=False)
        db = Chroma.from_documents(
            documents,
            embeddings,
            client_settings=Settings(anonymized_telemetry=False),
        )
        return RetrievalQA.from_chain_type(llm, retriever=db.as_retriever())
# (Removed: commented-out per-format ``load_txt_pipeline`` /
# ``load_pdf_pipeline`` duplicates — the generic ``load_pipeline``
# above handles both txt and pdf.)
uploaded_file = st.sidebar.file_uploader(
    'Upload file',
    type=['txt', 'pdf'],
    accept_multiple_files=False,
    on_change=set_file,
)
# Build the QA chain once a file is present.  ``uploaded_file`` can be
# None even when ``session_state.file`` is True (clearing the uploader
# also fires ``on_change``), so guard before touching ``.name``.
if st.session_state.file and uploaded_file is not None:
    fname = uploaded_file.name
    # Module-level extension, also consumed by ``rerun()`` for cleanup.
    fextension = fname[fname.rfind('.') + 1:]
    qa_chain = load_pipeline(uploaded_file)
# Replay the stored transcript on every rerun (Streamlit re-executes the
# whole script on each interaction).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What is up?"):
    # ``qa_chain`` is only bound after a successful upload; originally
    # chatting before uploading a document crashed with a NameError.
    if 'qa_chain' not in globals():
        st.warning('Please upload a file before asking questions.')
    else:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        response = qa_chain({"query": prompt})
        answer = response['result']
        with st.chat_message("assistant"):
            st.write(answer)
        st.session_state.messages.append(
            {"role": "assistant", "content": answer})