pdf_query_rag_llm_app.py
import os
import boto3
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain_aws.chat_models import ChatBedrock
from langchain_community.vectorstores.faiss import FAISS
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.globals import set_verbose

set_verbose(False)

# Shared Bedrock runtime client, used by both the Titan embeddings and the chat model
bedrock = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
titan_embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1", client=bedrock)
# Data Preparation
def data_ingestion():
    # Load every PDF in the data/ directory and split it into overlapping chunks
    loader = PyPDFDirectoryLoader("data")
    documents = loader.load()
    # Overlap keeps context that would otherwise be cut off at chunk boundaries
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    docs = text_splitter.split_documents(documents)
    return docs
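
# Optional sanity check (hypothetical, not part of the app flow): inspect the
# chunk count and preview the first chunk after ingestion.
#   chunks = data_ingestion()
#   print(len(chunks), "chunks; preview:", chunks[0].page_content[:200])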
# Vector Store Setup
def setup_vector_store(documents):
    vector_store = FAISS.from_documents(
        documents,
        titan_embeddings,
    )
    vector_store.save_local("faiss_index")
# LLM Setup
def load_llm():
    llm = ChatBedrock(
        model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
        client=bedrock,
        model_kwargs={"max_tokens": 2048},
    )
    return llm
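
# Optional smoke test (hypothetical): confirm Bedrock access before wiring up the UI.
#   print(load_llm().invoke("Reply with the single word: ready").content)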
# LLM Guidelines
prompt_template = """Use the following pieces of context to answer the question at the end. Follow these rules:
1. If the answer is not within the context knowledge, state that you do not know; do not fabricate an answer.
2. If you find the answer, write a detailed yet concise response. Aim for a summary of at most 200 words.
3. Do not add information that is not within the context.
{context}
Question: {question}
Helpful Answer:"""
# Prompt Template
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
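
# Hypothetical check: render the template with toy values to see the final prompt string.
#   print(prompt.format(context="(retrieved chunks)", question="(user question)"))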
# Create QA Chain
def get_result(llm, vector_store, query):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # Retrieve the 3 most similar chunks and stuff them into the prompt context
        retriever=vector_store.as_retriever(
            search_type="similarity", search_kwargs={"k": 3}
        ),
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    # Apply LLM
    result = qa_chain.invoke(query)
    return result['result']
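
# Example usage (hypothetical), assuming a "faiss_index" directory already exists on disk:
#   index = FAISS.load_local("faiss_index", embeddings=titan_embeddings,
#                            allow_dangerous_deserialization=True)
#   print(get_result(load_llm(), index, "What topics does the collection cover?"))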
# Streamlit Frontend UI Section
def streamlit_ui():
    st.set_page_config("PDF Query RAG LLM Application")
    # Hide Streamlit's default menu, footer, and deploy widgets for a cleaner page
    st.markdown("""
        <style>
            .reportview-container {
                margin-top: -2em;
            }
            #MainMenu {visibility: hidden;}
            .stDeployButton {display:none;}
            footer {visibility: hidden;}
            #stDecoration {display:none;}
        </style>
    """, unsafe_allow_html=True)
    st.header("PDF Query with Generative AI")
    user_question = st.text_input(r":gray[$\textsf{\normalsize Ask me anything about your PDF collection.}$]")
    left_column, middleleft_column, middleright_column, right_column = st.columns(4, gap="small")
    if left_column.button("Generate Response", key="submit_question") or user_question:
        # First check that the vector store exists
        if not os.path.exists("faiss_index"):
            st.error("Please build the vector store first with the New Data Update button.")
            return
        if not user_question:
            st.error("Please enter a question.")
            return
        with st.spinner("Generating... this may take a minute..."):
            faiss_index = FAISS.load_local("faiss_index", embeddings=titan_embeddings,
                                           allow_dangerous_deserialization=True)
            llm = load_llm()
            st.write(get_result(llm, faiss_index, user_question))
            st.success("Generated")
    # The two middle columns are spacers that keep the buttons at the page edges
    middleleft_column.empty()
    middleright_column.empty()
    if right_column.button("New Data Update", key="update_vector_store"):
        with st.spinner("Updating... this may take a few minutes as we go through your PDF collection..."):
            docs = data_ingestion()
            setup_vector_store(docs)
            st.success("Updated")

if __name__ == "__main__":
    streamlit_ui()
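
# To launch the app (assuming Streamlit is installed and AWS credentials with
# Bedrock access are configured, e.g. via `aws configure`):
#   streamlit run pdf_query_rag_llm_app.py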